Unverified Commit 097afc00 authored by Sudhakar Singh, committed by GitHub

fix model parallel encoder to be properly sharded params (#1794)



* fix model parallel encoder to be properly sharded
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent cd11e00d
@@ -4,9 +4,12 @@
"""Shared functions for the encoder tests"""
from functools import lru_cache
import jax
import jax.numpy
import transformer_engine
from transformer_engine_jax import get_device_compute_capability
from transformer_engine.common import recipe
import numpy as np
@lru_cache
@@ -30,6 +33,71 @@ def is_mxfp8_supported():
return gpu_arch >= 100
def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=False):
"""Checks whether most params are sharded across sharding axis.
(Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/315e551e5942b24656a4250dcfca986fb4135b72/MaxText/maxtext_utils.py#L348)
This function determines whether the majority of parameters are distributed
across a specified sharding axes with an acceptable tolerance. It compares the
current distribution to a scenario where all parameters are fully sharded
across the axes on which the params are sharded e.g. 'tensor' axis.
Args:
params: params of the model state
mesh: mesh constructed from config
tolerance: float between 0.0 and 1.0 representing the allowed percentage of
non-sharded parameters.
"""
def get_product_num_devices_for_weight_sharding(weight_sharding_axes):
product_num_devices_for_weight_sharding = 1
for axis in weight_sharding_axes:
product_num_devices_for_weight_sharding *= mesh.shape.get(axis, 1)
return product_num_devices_for_weight_sharding
def assert_leaf_sharding(path, arr):
# Is the weight sharded? Get the axes on which it is sharded.
partition_spec = arr.sharding.spec
weight_sharding_axes = set(partition_spec) - set([None]) # None is not a sharding axis
# Total number of devices on the axes on which the weight is sharded.
product_num_devices_for_weight_sharding = get_product_num_devices_for_weight_sharding(
weight_sharding_axes
)
# Params present in one shard (on one device).
shard = arr.addressable_shards[0]
params_per_chip = np.prod(shard.data.shape)
# Total number of params (across all devices).
total_params = jax.numpy.size(arr)
# Percentage of params that are unsharded.
unsharded_perc = (
(params_per_chip / (total_params / product_num_devices_for_weight_sharding) - 1) * 100
if params_per_chip < total_params
else 100
)
if print_info:
print(
f"{path}: {unsharded_perc:.2f}% unsharded, unsharded param shape={arr.shape},"
f" partition spec={partition_spec}"
)
# If the weight is sharded on any axis, then the percentage of
# unsharded params should be less than the tolerance.
assert (
product_num_devices_for_weight_sharding == 1 or unsharded_perc < tolerance
), f"{path}: {unsharded_perc:.2f}% unsharded"
jax.tree_util.tree_map_with_path(
lambda p, x: assert_leaf_sharding("/".join(str(x) for x in p), x), params
)
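Not part of the commit: a minimal standalone usage sketch of the helper above, assuming a ("data", "tensor") device mesh and a toy parameter tree (all names are illustrative). A kernel split over 'tensor' holds total/num_devices parameters per device, so its unsharded percentage is 0%, while the replicated bias is skipped because its PartitionSpec names no mesh axis.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Helper defined above; assumes this runs from the same tests/examples directory.
from common import assert_params_sufficiently_sharded

# Toy ("data", "tensor") mesh over all available devices.
devices = mesh_utils.create_device_mesh((1, len(jax.devices())))
mesh = Mesh(devices, axis_names=("data", "tensor"))

# Kernel sharded on its last dim over 'tensor' (assumes the 'tensor' axis size
# divides 256); bias replicated on every device.
kernel = jax.device_put(
    jnp.ones((128, 256)), NamedSharding(mesh, PartitionSpec(None, "tensor"))
)
bias = jax.device_put(jnp.ones((256,)), NamedSharding(mesh, PartitionSpec(None)))
params = {"dense": {"kernel": kernel, "bias": bias}}

# With N devices on 'tensor', each kernel shard holds 128 * 256 / N params,
# so unsharded_perc == 0.0 and the assertion passes; the bias contributes
# product_num_devices_for_weight_sharding == 1 and is therefore not checked.
assert_params_sufficiently_sharded(params, mesh, print_info=True)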
def get_fp8_recipe_from_name_string(name: str):
"""Query recipe from a given name string"""
match name:
......
@@ -19,7 +19,11 @@ from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
from common import is_bf16_supported, get_fp8_recipe_from_name_string
from common import (
is_bf16_supported,
get_fp8_recipe_from_name_string,
assert_params_sufficiently_sharded,
)
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
@@ -223,38 +227,6 @@ def check_fp8(state, var_collect, inputs, masks, labels):
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
def get_params_sharding(sharding_rules, abs_var_collect, mesh):
"""Refer params to create params sharding"""
rules_dict = dict(sharding_rules)
def to_device_axis(logical_axis):
partitions = [rules_dict[key] for key in logical_axis]
return NamedSharding(mesh, PartitionSpec(*partitions))
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_sharding = jax.tree_util.tree_map(
to_device_axis, nn_partitioning.get_axis_names(params_axes)
)
params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
params_sharding = jax.tree_util.tree_map(
lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY]
)
params_sharding = {**params_sharding, **params_axes_sharding}
return params_sharding
def get_state_sharding(state, params_sharding):
"""Refer params_sharding to create state sharding"""
def replace_params(x):
return params_sharding if isinstance(x, dict) else None
state_sharding = jax.tree_util.tree_map(
replace_params, state, is_leaf=lambda x: isinstance(x, dict)
)
return state_sharding
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
@@ -291,8 +263,9 @@ def train_and_evaluate(args):
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh:
) as mesh, nn_partitioning.axis_rules(
((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
@@ -312,25 +285,54 @@ def train_and_evaluate(args):
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
# Get the base axis rules and extend them with TE's rules.
axis_rules = nn_partitioning.get_axis_rules()
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
print(f"Device mesh: {mesh}")
print(f"Axis rules: {te_extended_axis_rules}")
logical_partition_spec = nn.get_partition_spec(abs_var_collect)
# Note that `nn.logical_to_mesh_sharding` returns a dict with an extra
# "params" key that contains the sharding for the parameters.
params_sharding = nn.logical_to_mesh_sharding(
logical_partition_spec, mesh, te_extended_axis_rules
)
inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))
in_shardings = (None, inputs_sharding, masks_sharding)
out_shardings = {
key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
key: params_sharding[PARAMS_KEY] if key is PARAMS_KEY else None
for key in abs_var_collect
}
jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
var_collect = jit_encoder_init(init_rngs, inputs, masks)
# Check if params are sufficiently sharded after initialization
assert_params_sufficiently_sharded(var_collect, mesh, print_info=False)
optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=params, tx=optimizer
)
state_sharding = get_state_sharding(state, params_sharding)
abs_state = jax.eval_shape(
lambda: train_state.TrainState.create(
apply_fn=encoder.apply, params=params, tx=optimizer
)
)
logical_state_partition_spec = nn.get_partition_spec(abs_state)
state_sharding = nn.logical_to_mesh_sharding(
logical_state_partition_spec, mesh, te_extended_axis_rules
)
# Check if params are sufficiently sharded after jitting the state creation
assert_params_sufficiently_sharded(state.params, mesh, print_info=False)
# state_sharding = get_state_sharding(state, params_sharding)
labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS))
in_shardings = (
......
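For reference (not part of the commit): the sharding-derivation pattern adopted above, annotating params with logical axis names, reading them back with nn.get_partition_spec, and mapping them onto the mesh with nn.logical_to_mesh_sharding, can be exercised standalone. Below is a minimal flax-only sketch with hypothetical logical axes 'embed'/'mlp' and mesh axes 'data'/'tensor'; the encoder example above additionally extends such rules with te_flax.extend_logical_axis_rules before use.

import flax.linen as nn
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# A Dense layer whose kernel carries the logical axis names ('embed', 'mlp').
model = nn.Dense(
    features=256,
    kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
)

# Toy ("data", "tensor") mesh over all available devices.
mesh = Mesh(
    mesh_utils.create_device_mesh((1, len(jax.devices()))),
    axis_names=("data", "tensor"),
)
# Logical-to-mesh axis rules: replicate 'embed', shard 'mlp' over 'tensor'.
rules = (("embed", None), ("mlp", "tensor"))

# Abstract init, then derive logical PartitionSpecs and concrete NamedShardings.
abs_vars = jax.eval_shape(model.init, jax.random.PRNGKey(0), jnp.ones((4, 128)))
logical_spec = nn.get_partition_spec(abs_vars)
shardings = nn.logical_to_mesh_sharding(logical_spec, mesh, rules)
# The result is keyed like abs_vars, so parameter shardings live under "params".
print(shardings["params"]["kernel"])  # spec is roughly PartitionSpec(None, 'tensor')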