Unverified Commit 309c6d49 authored by Frédéric Bastien, committed by GitHub

Jax example cleanup and replace pjit with jit. (#1107)



* Use jit instead of pjit

---------
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent a3353744
# Basic Transformer Encoder Example with Optional FP8 #

This example uses a Transformer Encoder to demonstrate Transformer Engine usage, with a focus on scaling up training to multiple GPUs. We highly recommend studying the [MNIST example of the Transformer Engine](/examples/jax/mnist) before reading this one. The Transformer Engine is built on top of [Flax](https://github.com/google/flax). Thus, the examples use `jit` with the `in_shardings` and `out_shardings` parameters to set up multi-GPU training. For basic parallel `jit` usage, refer to [Scale up Flax Modules on multiple devices](https://flax.readthedocs.io/en/latest/guides/flax_on_pjit.html).
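A minimal sketch (not from this commit) of the pattern the examples rely on; the mesh axis name `"data"`, the toy function, and the array shape are illustrative assumptions:

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1-D mesh over all visible devices; the axis name "data" is an assumption.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("data",))
# Shard the batch (first) dimension across "data"; replicate the second dimension.
batch_sharding = NamedSharding(mesh, PartitionSpec("data", None))

# jax.jit with explicit in_shardings/out_shardings replaces the old pjit call.
double = jax.jit(lambda x: x * 2.0, in_shardings=(batch_sharding,), out_shardings=batch_sharding)
y = double(jnp.ones((jax.device_count() * 2, 4)))
print(y.sharding)  # a NamedSharding matching batch_sharding
```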
## Single GPU ##
@@ -31,11 +31,11 @@ python test_single_gpu_encoder.py --use-fp8
4. The next step is to create sharding rules, mapping the device axis to the logical axis. `te.extend_logical_axis_rules` under `fp8_autocast` returns a list of mapping pairs, such as `(('batch', 'data'), ...)`, where the first element is the logical axis and the second is the device axis.
5. Refer to the structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up the `PartitionSpec` for parallel jit. All logical axes should be replaced by device axes. If the value of a PartitionSpec is None, that means no sharding, i.e., the data is broadcast to every device. Note that the `params_axes` attribute is provided by Transformer Engine; Flax's own modules, such as `nn.Embed`, don't have it. For `nn.Embed`, assigning an empty PartitionSpec is fine because each device has its own embedding layer in DP mode. The `get_params_sharding` routine is used for this purpose. Because each device has a complete model in DP mode, all values of PartitionSpec in `params_sharding` should be None. This will be different in the model parallelism example.
6. Pass `params_sharding` and `encoder.init` to `jax.jit` to get a compiled function, `jit_encoder_init`, and use it to initialize the model, so that JAX knows how to do the sharding.
7. The `train_step` and `eval_step` also need to be compiled by jit. Thus, every input and output argument that contains a tensor has to have a `PartitionSpec` set up. For instance, `input_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). The rest of the workflow is similar to the previous example.
8. Use `CUDA_VISIBLE_DEVICES` to control the number of GPUs used. For example, if the system has 8 GPUs but only 4 of them need to be used, then:
```sh
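# The original command is collapsed in this diff view; a typical invocation
# (script name assumed) would be:
CUDA_VISIBLE_DEVICES=0,1,2,3 python test_multigpu_encoder.py
```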
@@ -84,7 +84,7 @@ python test_model_parallel_encoder.py --use-fp8
1. This example inherits from the previous model parallelism example, but uses multiprocessing instead of single-program multiple-data (SPMD), with 1 GPU per process.
2. There are two main benefits of multiprocessing: it supports multi-node runs (see the sketch after this section), and it allows setting up hardware affinity for GPUs, such as NUMA binding. Affinity may help improve performance and stability. Please refer to [Best Practices When Benchmarking CUDA Applications](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2019-s9956/) for more details.
3. A quick way to check the system topology is to use `nvidia-smi`, for example:
```sh
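# The original content is collapsed in this diff view; the standard
# topology query is:
nvidia-smi topo -m
```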
...
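The multiprocessing notes above mention multi-node support. Below is a minimal sketch of per-process `jax.distributed` initialization; the coordinator address, process count, and device binding are illustrative assumptions, not values from this commit.

```python
import jax

# Minimal multi-node setup: each process drives one GPU. The values below are
# illustrative assumptions; a real launcher (e.g. SLURM or mpirun) supplies them.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # hypothetical coordinator host:port
    num_processes=2,                      # total processes across all nodes
    process_id=0,                         # unique rank of this process
    local_device_ids=[0],                 # bind this process to a single GPU
)

print(jax.process_index(), jax.local_devices(), jax.device_count())
```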
@@ -17,7 +17,7 @@ from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
@@ -223,32 +223,36 @@ def check_fp8(state, var_collect, inputs, masks, labels):
    )


def get_params_sharding(sharding_rules, abs_var_collect, mesh):
    """Refer params to create params sharding"""
    rules_dict = dict(sharding_rules)

    def to_device_axis(logical_axis):
        partitions = [rules_dict[key] for key in logical_axis]
        return NamedSharding(mesh, PartitionSpec(*partitions))

    params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
    params_axes_sharding = jax.tree_util.tree_map(
        to_device_axis, nn_partitioning.get_axis_names(params_axes)
    )
    params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
    # Params without TE axis annotations default to full replication (empty PartitionSpec).
    params_sharding = jax.tree_util.tree_map(
        lambda x: NamedSharding(mesh, PartitionSpec()), abs_var_collect[PARAMS_KEY]
    )
    params_sharding = {**params_sharding, **params_axes_sharding}
    return params_sharding


def get_state_sharding(state, params_sharding):
    """Refer params_sharding to create state sharding"""

    def replace_params(x):
        return params_sharding if isinstance(x, dict) else None

    state_sharding = jax.tree_util.tree_map(
        replace_params, state, is_leaf=lambda x: isinstance(x, dict)
    )
    return state_sharding


def train_and_evaluate(args):
@@ -270,7 +274,9 @@ def train_and_evaluate(args):
    ), f"Test batch size needs to be multiple of {num_gpu_dp}"

    device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
    with jax.sharding.Mesh(
        devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
    ) as mesh:
        rng = jax.random.PRNGKey(args.seed)
        rng, params_rng = jax.random.split(rng)
@@ -291,34 +297,39 @@ def train_and_evaluate(args):
        customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
        sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
        params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
        inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
        masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))
        in_shardings = (None, inputs_sharding, masks_sharding)
        out_shardings = {
            key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
        }
        jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
        var_collect = jit_encoder_init(init_rngs, inputs, masks)

        optimizer = optax.adamw(args.lr)
        var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
        state = train_state.TrainState.create(
            apply_fn=encoder.apply, params=params, tx=optimizer
        )
        state_sharding = get_state_sharding(state, params_sharding)
        labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS))

        in_shardings = (
            state_sharding,
            inputs_sharding,
            masks_sharding,
            labels_sharding,
            None,
            None,
        )
        out_shardings = (state_sharding, None, None, None)
        jit_train_step = jax.jit(train_step, in_shardings, out_shardings)

        in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
        out_shardings = (None, None)
        jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)

        if args.use_fp8:
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
@@ -327,7 +338,7 @@ def train_and_evaluate(args):
        if args.dry_run:
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
            rngs = {DROPOUT_KEY: dropout_rng}
            jit_train_step(state, inputs, masks, labels, var_collect, rngs)
            print("PASSED")
            return None
@@ -337,11 +348,11 @@ def train_and_evaluate(args):
            rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
            state, train_loss, train_accuracy, var_collect = train_epoch(
                state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
            )
            test_loss, test_accuracy = eval_model(
                state, test_ds, args.test_batch_size, var_collect, jit_eval_step
            )
            print(
...
@@ -17,7 +17,7 @@ from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
@@ -202,32 +202,36 @@ def check_fp8(state, var_collect, inputs, masks, labels):
    )


def get_params_sharding(sharding_rules, abs_var_collect, mesh):
    """Refer params to create params sharding"""
    rules_dict = dict(sharding_rules)

    def to_device_axis(logical_axis):
        partitions = [rules_dict[key] for key in logical_axis]
        return NamedSharding(mesh, PartitionSpec(*partitions))

    params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
    params_axes_sharding = jax.tree_util.tree_map(
        to_device_axis, nn_partitioning.get_axis_names(params_axes)
    )
    params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
    # Params without TE axis annotations default to full replication (empty PartitionSpec).
    params_sharding = jax.tree_util.tree_map(
        lambda x: NamedSharding(mesh, PartitionSpec()), abs_var_collect[PARAMS_KEY]
    )
    params_sharding = {**params_sharding, **params_axes_sharding}
    return params_sharding


def get_state_sharding(state, params_sharding):
    """Refer params_sharding to create state sharding"""

    def replace_params(x):
        return params_sharding if isinstance(x, dict) else None

    state_sharding = jax.tree_util.tree_map(
        replace_params, state, is_leaf=lambda x: isinstance(x, dict)
    )
    return state_sharding


def train_and_evaluate(args):
@@ -240,7 +244,7 @@ def train_and_evaluate(args):
    assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}"

    device_mesh = mesh_utils.create_device_mesh((num_gpu,))
    with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh:
        rng = jax.random.PRNGKey(args.seed)
        rng, params_rng = jax.random.split(rng)
@@ -260,34 +264,43 @@ def train_and_evaluate(args):
        abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
        sharding_rules = te_flax.extend_logical_axis_rules(tuple())
        params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
        inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
        masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))
        in_shardings = (None, inputs_sharding, masks_sharding)
        out_shardings = {
            key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
        }
        jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
        var_collect = jit_encoder_init(init_rngs, inputs, masks)

        optimizer = optax.adamw(args.lr)
        var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
        state = train_state.TrainState.create(
            apply_fn=encoder.apply, params=params, tx=optimizer
        )
        state_sharding = get_state_sharding(state, params_sharding)
        labels_sharding = NamedSharding(
            mesh,
            PartitionSpec(
                DEVICE_DP_AXIS,
            ),
        )

        in_shardings = (
            state_sharding,
            inputs_sharding,
            masks_sharding,
            labels_sharding,
            None,
            None,
        )
        out_shardings = (state_sharding, None, None, None)
        jit_train_step = jax.jit(train_step, in_shardings, out_shardings)

        in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
        out_shardings = (None, None)
        jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)

        if args.use_fp8:
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
@@ -296,7 +309,7 @@ def train_and_evaluate(args):
        if args.dry_run:
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
            rngs = {DROPOUT_KEY: dropout_rng}
            jit_train_step(state, inputs, masks, labels, var_collect, rngs)
            print("PASSED")
            return None
@@ -306,11 +319,11 @@ def train_and_evaluate(args):
            rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
            state, train_loss, train_accuracy, var_collect = train_epoch(
                state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
            )
            test_loss, test_accuracy = eval_model(
                state, test_ds, args.test_batch_size, var_collect, jit_eval_step
            )
            print(
...
@@ -19,7 +19,7 @@ from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
@@ -305,32 +305,36 @@ def check_fp8(state, var_collect, inputs, masks, labels):
    )


def get_params_sharding(sharding_rules, abs_var_collect, mesh):
    """Refer params to create params sharding"""
    rules_dict = dict(sharding_rules)

    def to_device_axis(logical_axis):
        partitions = [rules_dict[key] for key in logical_axis]
        return NamedSharding(mesh, jax.sharding.PartitionSpec(*partitions))

    params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
    params_axes_sharding = jax.tree_util.tree_map(
        to_device_axis, nn_partitioning.get_axis_names(params_axes)
    )
    params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
    # Params without TE axis annotations default to full replication (empty PartitionSpec).
    params_sharding = jax.tree_util.tree_map(
        lambda x: NamedSharding(mesh, PartitionSpec()), abs_var_collect[PARAMS_KEY]
    )
    params_sharding = {**params_sharding, **params_axes_sharding}
    return params_sharding


def get_state_sharding(state, params_sharding):
    """Refer params_sharding to create state sharding"""

    def replace_params(x):
        return params_sharding if isinstance(x, dict) else None

    state_sharding = jax.tree_util.tree_map(
        replace_params, state, is_leaf=lambda x: isinstance(x, dict)
    )
    return state_sharding


def train_and_evaluate(args):
@@ -362,7 +366,7 @@ def train_and_evaluate(args):
    device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
    with jax.sharding.Mesh(
        devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
    ) as mesh:
        rng = jax.random.PRNGKey(args.seed)
        rng, params_rng = jax.random.split(rng)
@@ -383,34 +387,41 @@ def train_and_evaluate(args):
        customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
        sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
        params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
        inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
        masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
        inputs_sharding = NamedSharding(mesh, inputs_pspec)
        masks_sharding = NamedSharding(mesh, masks_pspec)
        in_shardings = (None, inputs_sharding, masks_sharding)
        out_shardings = {
            key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
        }
        jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
        var_collect = jit_encoder_init(init_rngs, inputs, masks)

        optimizer = optax.adamw(args.lr)
        var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
        state = train_state.TrainState.create(
            apply_fn=encoder.apply, params=params, tx=optimizer
        )
        state_sharding = get_state_sharding(state, params_sharding)
        labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS))

        in_shardings = (
            state_sharding,
            inputs_sharding,
            masks_sharding,
            labels_sharding,
            None,
            None,
        )
        out_shardings = (state_sharding, None, None, None)
        jit_train_step = jax.jit(train_step, in_shardings, out_shardings)

        in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
        out_shardings = (None, None)
        jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)

        if args.use_fp8:
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
@@ -419,7 +430,7 @@ def train_and_evaluate(args):
        if args.dry_run:
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
            rngs = {DROPOUT_KEY: dropout_rng}
            jit_train_step(state, inputs, masks, labels, var_collect, rngs)
            print("PASSED")
        else:
            for epoch in range(1, args.epochs + 1):
@@ -433,11 +444,11 @@
                    args.batch_size,
                    rngs,
                    var_collect,
                    jit_train_step,
                    mesh,
                    inputs_pspec,
                    masks_pspec,
                    labels_sharding.spec,
                )
                test_loss, test_accuracy = eval_model(
@@ -445,11 +456,11 @@
                    test_ds,
                    args.test_batch_size,
                    var_collect,
                    jit_eval_step,
                    mesh,
                    inputs_pspec,
                    masks_pspec,
                    labels_sharding.spec,
                )
                if args.process_id == 0:
                    print(