Unverified Commit 235c8d00 authored by Phuong Nguyen, committed by GitHub

[JAX] Enable TE GEMM custom call for all recipes (#2047)



* enabled TE GEMM for all recipes
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* add warnings
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b6b3abce
@@ -34,7 +34,7 @@ class BasePrimitive(metaclass=ABCMeta):
     _is_enabled = True
     # Default list of primitives to disable for all recipes
-    _default_disable_names = ["GemmPrimitive"]
+    _default_disable_names = []

     @classmethod
     def enabled(cls):
...
@@ -15,6 +15,7 @@ quantization, and distributed training through sharding constraints.
 from typing import List, Tuple, Sequence, Union, Callable
 from functools import partial
+import warnings

 import jax
 import jax.numpy as jnp
@@ -92,6 +93,28 @@ def layernorm_mlp(
     """
     assert len(kernels) == 2

+    # For MaxText TP (= Megatron TP + sharding in the hidden dimension of the remaining
+    # unsharded activations), JAX dot_general may perform better than the TE GEMM custom call.
+    # This inspection only works if either norm_input_axes or dot_1_input_axes is set.
+    is_mxfp8 = (
+        False
+        if quantizer_sets[0] == noop_quantizer_set
+        else quantizer_sets[0].x.scaling_mode.is_1d_block_scaling()
+    )
+    inspect_axes = norm_input_axes or dot_1_input_axes
+    if (
+        inspect_axes is not None
+        and len(inspect_axes) == x.ndim
+        and inspect_axes[-1] is not None
+        and not is_mxfp8
+    ):
+        warnings.warn(
+            "Detected sharding in the hidden dimension of the MLP activation input. For improved"
+            " performance, consider using JAX's built-in `dot_general` implementation. To try"
+            " this, set the environment variable: `NVTE_JAX_CUSTOM_CALLS='GemmPrimitive=false'`",
+            UserWarning,
+        )
+
     kernel_1 = kernels[0]
     kernel_2 = kernels[1]
     bias_1 = biases[0]
...
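The new warning points at the `NVTE_JAX_CUSTOM_CALLS` environment variable as the opt-out path. A minimal usage sketch, assuming the variable is read when the TE/JAX package is imported (so it must be set beforehand, either in the launch environment or in the process as shown here):

    import os

    # Disable the TE GEMM custom call so GEMMs fall back to JAX's native
    # dot_general; the variable name and value come from the warning message above.
    os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false"

    import transformer_engine.jax  # import only after setting the variable

Exporting the variable in the launch environment before starting the process achieves the same effect and avoids any import-order concerns.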
@@ -352,9 +352,6 @@ class BlockScalingQuantizeConfig:
         cls.initialize(fp8_recipe)
         cls.AMAX_HISTORY_LEN = 0
-
-        # Use TE GEMM instead of JAX GEMM for better performance
-        tex.base.manage_primitives(enable_names=["GemmPrimitive"])

     @staticmethod
     def finalize() -> None:
         """Reset the block scaling configuration."""
...