Unverified commit 5b16807c, authored by jberchtold-nvidia and committed by GitHub
Browse files

[JAX] Use keyword args for jit in_shardings and out_shardings (#1898)



Use keyword args for jit in_shardings and out_shardings
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 0587ecf4
...@@ -307,7 +307,9 @@ def train_and_evaluate(args): ...@@ -307,7 +307,9 @@ def train_and_evaluate(args):
key: params_sharding[PARAMS_KEY] if key is PARAMS_KEY else None key: params_sharding[PARAMS_KEY] if key is PARAMS_KEY else None
for key in abs_var_collect for key in abs_var_collect
} }
jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) jit_encoder_init = jax.jit(
encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
)
var_collect = jit_encoder_init(init_rngs, inputs, masks) var_collect = jit_encoder_init(init_rngs, inputs, masks)
# Check if params are sufficiently sharded after initialization # Check if params are sufficiently sharded after initialization
...@@ -344,11 +346,15 @@ def train_and_evaluate(args): ...@@ -344,11 +346,15 @@ def train_and_evaluate(args):
None, None,
) )
out_shardings = (state_sharding, None, None, None) out_shardings = (state_sharding, None, None, None)
jit_train_step = jax.jit(train_step, in_shardings, out_shardings) jit_train_step = jax.jit(
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None) out_shardings = (None, None)
jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
)
if args.use_fp8: if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
......
...@@ -288,7 +288,9 @@ def train_and_evaluate(args): ...@@ -288,7 +288,9 @@ def train_and_evaluate(args):
out_shardings = { out_shardings = {
key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
} }
jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) jit_encoder_init = jax.jit(
encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
)
var_collect = jit_encoder_init(init_rngs, inputs, masks) var_collect = jit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr) optimizer = optax.adamw(args.lr)
...@@ -312,11 +314,15 @@ def train_and_evaluate(args): ...@@ -312,11 +314,15 @@ def train_and_evaluate(args):
None, None,
) )
out_shardings = (state_sharding, None, None, None) out_shardings = (state_sharding, None, None, None)
jit_train_step = jax.jit(train_step, in_shardings, out_shardings) jit_train_step = jax.jit(
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None) out_shardings = (None, None)
jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
)
if args.use_fp8: if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
......
...@@ -412,7 +412,9 @@ def train_and_evaluate(args): ...@@ -412,7 +412,9 @@ def train_and_evaluate(args):
out_shardings = { out_shardings = {
key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
} }
jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) jit_encoder_init = jax.jit(
encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
)
var_collect = jit_encoder_init(init_rngs, inputs, masks) var_collect = jit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr) optimizer = optax.adamw(args.lr)
...@@ -432,11 +434,15 @@ def train_and_evaluate(args): ...@@ -432,11 +434,15 @@ def train_and_evaluate(args):
None, None,
) )
out_shardings = (state_sharding, None, None, None) out_shardings = (state_sharding, None, None, None)
jit_train_step = jax.jit(train_step, in_shardings, out_shardings) jit_train_step = jax.jit(
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None) out_shardings = (None, None)
jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
)
if args.use_fp8: if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment