Unverified commit c0c12e20 authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Support Flax sharding constraints (#1933)



* Support flax sharding constraints
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Add warning for deprecated TE logical axes
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update examples
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 6c526794
......@@ -264,8 +264,10 @@ def train_and_evaluate(args):
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh, nn_partitioning.axis_rules(
((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
) as mesh, te.fp8_autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
......@@ -276,22 +278,21 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
# Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
with flax.linen.logical_axis_rules(te_extended_axis_rules):
print(f"Device mesh: {mesh}")
print(f"Axis rules: {te_extended_axis_rules}")
encoder = Net(num_embed, args.enable_sp)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
# Get the base axis rules and extend them with TE's rules.
axis_rules = nn_partitioning.get_axis_rules()
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
print(f"Device mesh: {mesh}")
print(f"Axis rules: {te_extended_axis_rules}")
logical_partition_spec = nn.get_partition_spec(abs_var_collect)
# Note that `nn.logical_to_mesh_sharding` returns a dict with an extra
......
......@@ -259,7 +259,13 @@ def train_and_evaluate(args):
fp8_recipe = None
device_mesh = mesh_utils.create_device_mesh((num_gpu,))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh:
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)
) as mesh, te.fp8_autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
......@@ -270,17 +276,14 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
):
# Add TE logical axis rules to our Flax logical axis rule context. This must be done inside fp8_autocast
sharding_rules = te_flax.extend_logical_axis_rules(tuple())
with flax.linen.logical_axis_rules(sharding_rules):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
sharding_rules = te_flax.extend_logical_axis_rules(tuple())
params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))
......
......@@ -379,8 +379,11 @@ def train_and_evaluate(args):
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh:
) as mesh, te.fp8_autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
......@@ -390,18 +393,18 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
# Create custom Flax logical axis rules for sharding.
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
# Extend the logical axis rules with TE's rules. This must be done inside fp8_autocast.
sharding_rules = te_flax.extend_logical_axis_rules(customized_rules)
with flax.linen.logical_axis_rules(sharding_rules):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
......
......@@ -180,8 +180,9 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
# (b, h, q, k): Last two axes are always replicated
attn_weights = with_sharding_constraint_by_logical_axes(
attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)
attn_weights, (BATCH_AXES, HEAD_AXES, None, None)
)
# When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
......
......@@ -14,6 +14,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import warnings
from jax.interpreters import pxla
import jax
import jax.numpy as jnp
......@@ -117,7 +118,9 @@ def with_sharding_constraint_by_logical_axes(
x: jnp.array, logical_axis_names: Optional[tuple | list]
):
"""
A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
A wrapper function to flax.linen.with_logical_constraint.
DEPRECATED USE CASE: If no Flax logical axis rules are available, this function falls back to jax.lax.with_sharding_constraint using a hardcoded logical axis rule table from TE rules, such as BATCH_AXES. This functionality will be removed in the future.
If logical_axis_names = None, this means no sharding constraint is applied.
......@@ -133,6 +136,28 @@ def with_sharding_constraint_by_logical_axes(
if not logical_axis_names:
return x
try:
# Check if Flax logical axis rules are available, if so use them
import flax
flax_rules = flax.linen.get_logical_axis_rules()
if len(flax_rules) > 0:
return flax.linen.with_logical_constraint(
x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT
)
except ImportError:
pass
warnings.warn(
"TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and"
" will be removed in a future version. Please use Flax logical axes with a"
" flax.linen.logical_axis_rules context and optionally use"
" transformer_engine.jax.flax.extend_logical_axis_rules to add BATCH_AXES, etc. to your"
" rules.",
DeprecationWarning,
)
# If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table
assert len(x.shape) == len(logical_axis_names)
pspec = generate_pspec(logical_axis_names)
return with_sharding_constraint(x, pspec)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment