Unverified Commit bbddcb92 authored by Phuong Nguyen, committed by GitHub

[JAX] Cleanup the MLP warning for TE GEMM + TP (#2054)



* fix pspec check
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* cleaning
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* add docstring
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* use dict.get()
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fix lint
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 8dc2756e
@@ -15,7 +15,6 @@ quantization, and distributed training through sharding constraints.
 from typing import List, Tuple, Sequence, Union, Callable
 from functools import partial
-import warnings
 import jax
 import jax.numpy as jnp
@@ -93,28 +92,6 @@ def layernorm_mlp(
     """
     assert len(kernels) == 2
-    # For MaxText TP (= Megatron TP + sharding in hidden dimension of remaining unsharded
-    # activations), JAX dot_general may perform better then TE GEMM custom call
-    # This inspection only works if either norm_input_axes or dot_1_input_axes is set
-    is_mxfp8 = (
-        False
-        if quantizer_sets[0] == noop_quantizer_set
-        else quantizer_sets[0].x.scaling_mode.is_1d_block_scaling()
-    )
-    inspect_axes = norm_input_axes or dot_1_input_axes
-    if (
-        inspect_axes is not None
-        and len(inspect_axes) == x.ndim
-        and inspect_axes[-1] is not None
-        and not is_mxfp8
-    ):
-        warnings.warn(
-            "Detected sharding in the hidden dimension of the MLP activation input. For improved"
-            " performance, consider using JAX’s built-in `dot_general` implementation. To try"
-            " this, set the environment variable: `NVTE_JAX_CUSTOM_CALLS='GemmPrimitive=false'`",
-            UserWarning,
-        )
     kernel_1 = kernels[0]
     kernel_2 = kernels[1]
     bias_1 = biases[0]
......
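Note: the removed warning pointed users at the `NVTE_JAX_CUSTOM_CALLS` environment variable as the way to fall back to JAX's own `dot_general` for these GEMMs. A minimal sketch of that opt-out, assuming the variable is read when Transformer Engine is imported (the exact timing is an assumption, not confirmed by this diff):

    import os

    # Assumption: disable the TE GEMM custom call before importing transformer_engine.jax,
    # so the MLP GEMMs are lowered through JAX's dot_general instead.
    os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false"

    import transformer_engine.jax  # noqa: E402  (imported after setting the env var)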
@@ -86,37 +86,20 @@ def get_sharding_map_logic_axis_to_mesh_axis():
     return te_logical_axis_to_mesh_axis


-def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False):
+def _generate_pspec(logical_axis_names):
     """
-    Convert logical axes to PartitionSpec
+    Convert TransformerEngine logical axes (e.g. BATCH_AXES) to a JAX PartitionSpec.
+    Note, this method does not support Flax logical axes.
+
+    Args:
+        logical_axis_names: TransformerEngine logical axes to convert to a JAX PartitionSpec.
+
+    Returns:
+        A JAX PartitionSpec with the mesh axes corresponding to the given TransformerEngine logical axis names
     """
-    rules = None
-    if with_flax_rules:
-        try:
-            import flax
-
-            rules = dict(flax.linen.get_logical_axis_rules())
-        except ImportError:
-            pass
-    if rules is None:
-        warnings.warn(
-            "Transformer Engine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated"
-            " and removed in a future version. Please use Flax logical axes with the"
-            " `flax.linen.logical_axis_rules()` context and optionally use"
-            " `transformer_engine.jax.flax.extend_logical_axis_rules()` to extend Flax axis rules"
-            " with Transformer Engine logical axes.",
-            DeprecationWarning,
-        )
-        rules = get_sharding_map_logic_axis_to_mesh_axis()
-    # mesh_axis_names = [rules[name] for name in logical_axis_names]
-    mesh_axis_names = []
-    for name in logical_axis_names:
-        axis_name = rules[name] if name in rules else None
-        mesh_axis_names.append(axis_name)
+    rules = get_sharding_map_logic_axis_to_mesh_axis()
+    mesh_axis_names = [rules.get(name) for name in logical_axis_names]
     pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
-    if padded:
-        pspec = get_padded_spec(pspec, len(mesh_axis_names))
     return pspec
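For reference, a minimal sketch of how the simplified helper behaves; the rules table and logical axis names below are illustrative placeholders, not the real TE constants:

    import jax

    # Hypothetical logical-axis -> mesh-axis table, standing in for the result of
    # get_sharding_map_logic_axis_to_mesh_axis() in the real code.
    rules = {"nvte_batch": "data", "nvte_mlp_hidden": "tensor"}
    logical_axis_names = ("nvte_batch", "nvte_seqlen", "nvte_mlp_hidden")

    # dict.get() yields None for unmapped axes, replacing the explicit loop.
    mesh_axis_names = [rules.get(name) for name in logical_axis_names]
    pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
    # -> PartitionSpec('data', None, 'tensor')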
@@ -188,7 +171,7 @@ def with_sharding_constraint_by_logical_axes(
     # If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table
     assert len(x.shape) == len(logical_axis_names)
-    pspec = generate_pspec(logical_axis_names)
+    pspec = _generate_pspec(logical_axis_names)
     return with_sharding_constraint(x, pspec)
......
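A hedged caller-level sketch of the fallback path above, assuming `with_sharding_constraint_by_logical_axes` is importable from `transformer_engine.jax.sharding` and is called under an active mesh (the axis names are illustrative placeholders):

    import jax.numpy as jnp
    from transformer_engine.jax.sharding import with_sharding_constraint_by_logical_axes

    x = jnp.zeros((8, 128, 1024))  # (batch, seqlen, hidden)
    # One logical axis name per tensor dimension, matching the assert in the diff.
    x = with_sharding_constraint_by_logical_axes(x, ("nvte_batch", "nvte_seqlen", "nvte_mlp_hidden"))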