Unverified Commit bbddcb92 authored by Phuong Nguyen, committed by GitHub

[JAX] Cleanup the MLP warning for TE GEMM + TP (#2054)



* fix pspec check
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* cleaning
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* add docstring
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* use dict.get()
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fix lint
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 8dc2756e
@@ -15,7 +15,6 @@ quantization, and distributed training through sharding constraints.
 from typing import List, Tuple, Sequence, Union, Callable
 from functools import partial
-import warnings
 import jax
 import jax.numpy as jnp
@@ -93,28 +92,6 @@ def layernorm_mlp(
     """
     assert len(kernels) == 2
-    # For MaxText TP (= Megatron TP + sharding in hidden dimension of remaining unsharded
-    # activations), JAX dot_general may perform better then TE GEMM custom call
-    # This inspection only works if either norm_input_axes or dot_1_input_axes is set
-    is_mxfp8 = (
-        False
-        if quantizer_sets[0] == noop_quantizer_set
-        else quantizer_sets[0].x.scaling_mode.is_1d_block_scaling()
-    )
-    inspect_axes = norm_input_axes or dot_1_input_axes
-    if (
-        inspect_axes is not None
-        and len(inspect_axes) == x.ndim
-        and inspect_axes[-1] is not None
-        and not is_mxfp8
-    ):
-        warnings.warn(
-            "Detected sharding in the hidden dimension of the MLP activation input. For improved"
-            " performance, consider using JAX’s built-in `dot_general` implementation. To try"
-            " this, set the environment variable: `NVTE_JAX_CUSTOM_CALLS='GemmPrimitive=false'`",
-            UserWarning,
-        )
     kernel_1 = kernels[0]
     kernel_2 = kernels[1]
     bias_1 = biases[0]
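Note: the warning removed above pointed users at an environment-variable escape hatch. A minimal sketch of that workaround, assuming the variable must be set before any Transformer Engine GEMM runs (that import-ordering requirement is an assumption, not documented behavior):

    import os

    # Disable the TE GEMM custom call so JAX falls back to its built-in
    # dot_general, as the removed warning suggested. Assumption: the variable
    # is read before the first TE kernel launch, so set it before importing TE.
    os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false"

    import transformer_engine.jax  # imported after setting the variable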
@@ -86,37 +86,20 @@ def get_sharding_map_logic_axis_to_mesh_axis():
     return te_logical_axis_to_mesh_axis


-def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False):
+def _generate_pspec(logical_axis_names):
     """
-    Convert logical axes to PartitionSpec
+    Convert TransformerEngine logical axes (e.g. BATCH_AXES) to a JAX PartitionSpec.
+
+    Note, this method does not support Flax logical axes.
+
+    Args:
+        logical_axis_names: TransformerEngine logical axes to convert to a JAX PartitionSpec.
+
+    Returns:
+        A JAX PartitionSpec with the mesh axes corresponding to the given TransformerEngine logical axis names
     """
-    rules = None
-    if with_flax_rules:
-        try:
-            import flax
-
-            rules = dict(flax.linen.get_logical_axis_rules())
-        except ImportError:
-            pass
-    if rules is None:
-        warnings.warn(
-            "Transformer Engine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated"
-            " and removed in a future version. Please use Flax logical axes with the"
-            " `flax.linen.logical_axis_rules()` context and optionally use"
-            " `transformer_engine.jax.flax.extend_logical_axis_rules()` to extend Flax axis rules"
-            " with Transformer Engine logical axes.",
-            DeprecationWarning,
-        )
-        rules = get_sharding_map_logic_axis_to_mesh_axis()
-    # mesh_axis_names = [rules[name] for name in logical_axis_names]
-    mesh_axis_names = []
-    for name in logical_axis_names:
-        axis_name = rules[name] if name in rules else None
-        mesh_axis_names.append(axis_name)
+    rules = get_sharding_map_logic_axis_to_mesh_axis()
+    mesh_axis_names = [rules.get(name) for name in logical_axis_names]
     pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
-    if padded:
-        pspec = get_padded_spec(pspec, len(mesh_axis_names))
     return pspec
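For reference, a minimal sketch of the new lookup path, using a hypothetical rules table (TE's real table comes from get_sharding_map_logic_axis_to_mesh_axis(); the axis names below are made up): dict.get() returns None for unmapped names, and PartitionSpec treats None as a replicated dimension.

    import jax

    # Hypothetical logical-axis -> mesh-axis table, for illustration only.
    rules = {"batch": "data", "hidden_tp": "tensor"}
    logical_axis_names = ("batch", "seqlen", "hidden_tp")

    # Equivalent to the old `rules[name] if name in rules else None` loop:
    # dict.get() yields None for names missing from the table.
    mesh_axis_names = [rules.get(name) for name in logical_axis_names]
    pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
    print(pspec)  # PartitionSpec('data', None, 'tensor')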
@@ -188,7 +171,7 @@ def with_sharding_constraint_by_logical_axes(
     # If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table
     assert len(x.shape) == len(logical_axis_names)
-    pspec = generate_pspec(logical_axis_names)
+    pspec = _generate_pspec(logical_axis_names)
     return with_sharding_constraint(x, pspec)
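The call site above follows the standard JAX pattern: build a PartitionSpec from logical axes, then apply a sharding constraint under jit. A self-contained sketch of that pattern with an illustrative two-axis mesh (the mesh, axis names, and the constrain() helper are assumptions for illustration, not TE's API):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Illustrative 1x1 mesh on a single device; a real TP run spans many devices.
    devices = np.array(jax.devices()[:1]).reshape(1, 1)
    mesh = Mesh(devices, axis_names=("data", "tensor"))

    @jax.jit
    def constrain(x):
        # Same invariant the patched code asserts: one pspec entry per array dim.
        pspec = PartitionSpec("data", "tensor")
        assert len(x.shape) == len(pspec)
        return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, pspec))

    y = constrain(jnp.ones((4, 8)))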