"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "f04908cae782e1a2404eb3e4f331718d311d1e0d"
Unverified Commit 5b16807c authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Use keyword args for jit in_shardings and out_shardings (#1898)



Use keyword args for jit in_shardings and out_shardings
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 0587ecf4
......@@ -307,7 +307,9 @@ def train_and_evaluate(args):
key: params_sharding[PARAMS_KEY] if key is PARAMS_KEY else None
for key in abs_var_collect
}
jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
jit_encoder_init = jax.jit(
encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
)
var_collect = jit_encoder_init(init_rngs, inputs, masks)
# Check if params are sufficiently sharded after initialization
......@@ -344,11 +346,15 @@ def train_and_evaluate(args):
None,
)
out_shardings = (state_sharding, None, None, None)
jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
jit_train_step = jax.jit(
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None)
jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
......
......@@ -288,7 +288,9 @@ def train_and_evaluate(args):
out_shardings = {
key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
}
jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
jit_encoder_init = jax.jit(
encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
)
var_collect = jit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
......@@ -312,11 +314,15 @@ def train_and_evaluate(args):
None,
)
out_shardings = (state_sharding, None, None, None)
jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
jit_train_step = jax.jit(
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None)
jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
......
......@@ -412,7 +412,9 @@ def train_and_evaluate(args):
out_shardings = {
key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
}
jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
jit_encoder_init = jax.jit(
encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
)
var_collect = jit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
......@@ -432,11 +434,15 @@ def train_and_evaluate(args):
None,
)
out_shardings = (state_sharding, None, None, None)
jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
jit_train_step = jax.jit(
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None)
jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment