Unverified Commit c0c12e20 authored by jberchtold-nvidia, committed by GitHub

[JAX] Support Flax sharding constraints (#1933)



* Support Flax sharding constraints
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Add warning for deprecated TE logical axes
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update examples
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 6c526794
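
All three example updates below follow the same pattern: te.fp8_autocast moves up into the mesh `with` statement, and the model is built inside a flax.linen.logical_axis_rules context whose rules come from te_flax.extend_logical_axis_rules. Below is a minimal sketch of that pattern; the mesh split, axis names, the placeholder user rule, and the DenseGeneral toy model are illustrative assumptions, not code from this commit.

# Minimal sketch of the pattern used in the updated examples (placeholder values).
import flax
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

DEVICE_DP_AXIS, DEVICE_TP_AXIS = "data", "model"

# Put all devices on the data-parallel axis for simplicity.
device_mesh = mesh_utils.create_device_mesh((jax.device_count(), 1))
with jax.sharding.Mesh(
    devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh, te.fp8_autocast(
    enabled=False,  # set True on FP8-capable GPUs
    mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
    # Start from any Flax rules already in scope and add TE's rules.
    # Per the example comments, this must be done inside fp8_autocast.
    axis_rules = flax.linen.get_logical_axis_rules()
    axis_rules += (("my_embed", None),)  # user-defined logical axis (placeholder)
    te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)

    with flax.linen.logical_axis_rules(te_extended_axis_rules):
        model = te_flax.DenseGeneral(features=128)  # toy model for illustration
        params = model.init(jax.random.PRNGKey(0), jnp.ones((8, 128), dtype=jnp.float32))
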
@@ -264,8 +264,10 @@ def train_and_evaluate(args):
         device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
         with jax.sharding.Mesh(
             devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
-        ) as mesh, nn_partitioning.axis_rules(
-            ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
+        ) as mesh, te.fp8_autocast(
+            enabled=args.use_fp8,
+            fp8_recipe=fp8_recipe,
+            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
         ):
             rng = jax.random.PRNGKey(args.seed)
             rng, params_rng = jax.random.split(rng)
@@ -276,22 +278,21 @@ def train_and_evaluate(args):
             mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
             label_shape = [args.batch_size]
 
-            with te.fp8_autocast(
-                enabled=args.use_fp8,
-                fp8_recipe=fp8_recipe,
-                mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
-            ):
+            # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
+            axis_rules = flax.linen.get_logical_axis_rules()
+            axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
+            te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
+
+            with flax.linen.logical_axis_rules(te_extended_axis_rules):
+                print(f"Device mesh: {mesh}")
+                print(f"Axis rules: {te_extended_axis_rules}")
+
                 encoder = Net(num_embed, args.enable_sp)
                 inputs = jnp.zeros(input_shape, dtype=jnp.int32)
                 masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
                 abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
 
-                # Get the base axis rules and extend them with TE's rules.
-                axis_rules = nn_partitioning.get_axis_rules()
-                te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
-                print(f"Device mesh: {mesh}")
-                print(f"Axis rules: {te_extended_axis_rules}")
                 logical_partition_spec = nn.get_partition_spec(abs_var_collect)
 
                 # Note that `nn.logical_to_mesh_sharding` returns a dict with an extra
...
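
Right after the visible hunk, the example converts the logical partition specs on the abstract variables into concrete mesh shardings. A hedged sketch of that step, reusing mesh, abs_var_collect, and te_extended_axis_rules from the code above; the call to flax.linen.logical_to_mesh_sharding is an assumption here, since the example's follow-up code is truncated.

# Sketch: translate logical axis names on parameters into device shardings.
# `abs_var_collect`, `mesh`, and `te_extended_axis_rules` come from the code above.
import flax.linen as nn

# Partition specs expressed with logical axis names.
logical_partition_spec = nn.get_partition_spec(abs_var_collect)

# Map logical names to mesh axes using the extended rules.
params_sharding = nn.logical_to_mesh_sharding(
    logical_partition_spec, mesh, rules=te_extended_axis_rules
)
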
@@ -259,7 +259,13 @@ def train_and_evaluate(args):
         fp8_recipe = None
 
     device_mesh = mesh_utils.create_device_mesh((num_gpu,))
-    with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh:
+    with jax.sharding.Mesh(
+        devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)
+    ) as mesh, te.fp8_autocast(
+        enabled=args.use_fp8,
+        fp8_recipe=fp8_recipe,
+        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
+    ):
         rng = jax.random.PRNGKey(args.seed)
         rng, params_rng = jax.random.split(rng)
@@ -270,17 +276,14 @@ def train_and_evaluate(args):
         mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
         label_shape = [args.batch_size]
 
-        with te.fp8_autocast(
-            enabled=args.use_fp8,
-            fp8_recipe=fp8_recipe,
-            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
-        ):
+        # Add TE logical axis rules to our Flax logical axis rule context. This must be done inside fp8_autocast
+        sharding_rules = te_flax.extend_logical_axis_rules(tuple())
+        with flax.linen.logical_axis_rules(sharding_rules):
             encoder = Net(num_embed)
             inputs = jnp.zeros(input_shape, dtype=jnp.int32)
             masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
             abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
 
-            sharding_rules = te_flax.extend_logical_axis_rules(tuple())
             params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
             inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
             masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))
...
@@ -379,8 +379,11 @@ def train_and_evaluate(args):
     device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
     with jax.sharding.Mesh(
         devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
-    ) as mesh:
+    ) as mesh, te.fp8_autocast(
+        enabled=args.use_fp8,
+        fp8_recipe=fp8_recipe,
+        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
+    ):
         rng = jax.random.PRNGKey(args.seed)
         rng, params_rng = jax.random.split(rng)
         rng, dropout_rng = jax.random.split(rng)
@@ -390,18 +393,18 @@ def train_and_evaluate(args):
         mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
         label_shape = [args.batch_size]
 
-        with te.fp8_autocast(
-            enabled=args.use_fp8,
-            fp8_recipe=fp8_recipe,
-            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
-        ):
+        # Create custom Flax logical axis rules for sharding.
+        customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
+
+        # Extend the logical axis rules with TE's rules. This must be done inside fp8_autocast.
+        sharding_rules = te_flax.extend_logical_axis_rules(customized_rules)
+
+        with flax.linen.logical_axis_rules(sharding_rules):
             encoder = Net(num_embed)
             inputs = jnp.zeros(input_shape, dtype=jnp.int32)
             masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
             abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
 
-            customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
-            sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
             params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
             inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
             masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
...
@@ -180,8 +180,9 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-
             attn_weights_without_groups_shape = (b, h * g, q, k)
             attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
 
+            # (b, h, q, k): Last two axes are always replicated
             attn_weights = with_sharding_constraint_by_logical_axes(
-                attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)
+                attn_weights, (BATCH_AXES, HEAD_AXES, None, None)
             )
 
             # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
...
@@ -14,6 +14,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum
 from typing import Callable, Optional
+import warnings
 from jax.interpreters import pxla
 import jax
 import jax.numpy as jnp
@@ -117,7 +118,9 @@ def with_sharding_constraint_by_logical_axes(
     x: jnp.array, logical_axis_names: Optional[tuple | list]
 ):
     """
-    A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
+    A wrapper function to flax.linen.with_logical_constraint.
+
+    DEPRECATED USE CASE: If no Flax logical axis rules are available, this function falls back to
+    jax.lax.with_sharding_constraint using a hardcoded logical axis rule table from TE rules, such
+    as BATCH_AXES. This functionality will be removed in the future.
 
     If logical_axis_names = None, this means no sharding constraint is applied.
@@ -133,6 +136,28 @@ with_sharding_constraint_by_logical_axes(
     if not logical_axis_names:
         return x
 
+    try:
+        # Check if Flax logical axis rules are available, if so use them
+        import flax
+
+        flax_rules = flax.linen.get_logical_axis_rules()
+        if len(flax_rules) > 0:
+            return flax.linen.with_logical_constraint(
+                x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT
+            )
+    except ImportError:
+        pass
+
+    warnings.warn(
+        "TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and"
+        " will be removed in a future version. Please use Flax logical axes with a"
+        " flax.linen.logical_axis_rules context and optionally use"
+        " transformer_engine.jax.flax.extend_logical_axis_rules to add BATCH_AXES, etc. to your"
+        " rules.",
+        DeprecationWarning,
+    )
+
+    # If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table
     assert len(x.shape) == len(logical_axis_names)
     pspec = generate_pspec(logical_axis_names)
     return with_sharding_constraint(x, pspec)
...
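
The net effect of the sharding.py change: when a Flax logical-axis-rules context is active, with_sharding_constraint_by_logical_axes defers to flax.linen.with_logical_constraint; without one, it emits a DeprecationWarning and falls back to TE's hardcoded axis table. A small usage sketch follows; the import paths, mesh setup, and axis names are assumptions, not taken from this diff.

# Sketch of the two code paths after this change (import paths and axis names assumed).
import flax
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from transformer_engine.jax.sharding import (
    BATCH_AXES,
    with_sharding_constraint_by_logical_axes,
)

mesh = jax.sharding.Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("data",))
x = jnp.zeros((32, 128))

with mesh:
    # Preferred path: active Flax rules map logical names to mesh axes, so the
    # helper dispatches to flax.linen.with_logical_constraint.
    with flax.linen.logical_axis_rules(((BATCH_AXES, "data"),)):
        x = with_sharding_constraint_by_logical_axes(x, (BATCH_AXES, None))

    # Deprecated path: no Flax rules in scope, so the helper emits a
    # DeprecationWarning and falls back to TE's hardcoded logical-axis table.
    x = with_sharding_constraint_by_logical_axes(x, (BATCH_AXES, None))
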