Unverified Commit 235c8d00 authored by Phuong Nguyen, committed by GitHub

[JAX] Enable TE GEMM custom call for all recipes (#2047)



* enabled TE GEMM for all recipes
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* add warnings
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b6b3abce
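
This change removes `GemmPrimitive` from the default disable list, so the TE GEMM custom call is now on for every recipe. A minimal sketch of how a user could opt back into JAX's native path via the `NVTE_JAX_CUSTOM_CALLS` variable referenced in the new warning (the matmul smoke test below is hypothetical, not part of this PR):

```python
import os

# Assumption: the flag is read when TE's JAX extensions register their
# custom calls, so it should be set before importing transformer_engine.
os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false"

import jax
import jax.numpy as jnp

# Hypothetical smoke test: with GemmPrimitive disabled, GEMMs inside TE
# layers fall back to JAX's built-in dot_general lowering.
x = jnp.ones((128, 256), dtype=jnp.bfloat16)
w = jnp.ones((256, 512), dtype=jnp.bfloat16)
y = jax.jit(lambda a, b: a @ b)(x, w)
print(y.shape)  # (128, 512)
```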
@@ -34,7 +34,7 @@ class BasePrimitive(metaclass=ABCMeta):
     _is_enabled = True
     # Default list of primitives to disable for all recipes
-    _default_disable_names = ["GemmPrimitive"]
+    _default_disable_names = []

     @classmethod
     def enabled(cls):
......
@@ -15,6 +15,7 @@ quantization, and distributed training through sharding constraints.
 from typing import List, Tuple, Sequence, Union, Callable
 from functools import partial
+import warnings
 import jax
 import jax.numpy as jnp
@@ -92,6 +93,28 @@ def layernorm_mlp(
     """
     assert len(kernels) == 2
+    # For MaxText TP (= Megatron TP + sharding in the hidden dimension of the remaining
+    # unsharded activations), JAX dot_general may perform better than the TE GEMM custom call.
+    # This inspection only works if either norm_input_axes or dot_1_input_axes is set.
+    is_mxfp8 = (
+        False
+        if quantizer_sets[0] == noop_quantizer_set
+        else quantizer_sets[0].x.scaling_mode.is_1d_block_scaling()
+    )
+    inspect_axes = norm_input_axes or dot_1_input_axes
+    if (
+        inspect_axes is not None
+        and len(inspect_axes) == x.ndim
+        and inspect_axes[-1] is not None
+        and not is_mxfp8
+    ):
+        warnings.warn(
+            "Detected sharding in the hidden dimension of the MLP activation input. For improved"
+            " performance, consider using JAX’s built-in `dot_general` implementation. To try"
+            " this, set the environment variable: `NVTE_JAX_CUSTOM_CALLS='GemmPrimitive=false'`",
+            UserWarning,
+        )
     kernel_1 = kernels[0]
     kernel_2 = kernels[1]
     bias_1 = biases[0]
......
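
The warning above fires only when the last (hidden) axis of the activation input is named, i.e. sharded, and the recipe is not MXFP8. A tiny standalone illustration of the check; the logical axis names below are hypothetical:

```python
# Hypothetical logical axes for an activation of shape (batch, seq, hidden);
# a non-None last entry means the hidden dimension is sharded.
dot_1_input_axes = ("batch", "seq", "tp")
norm_input_axes = None
x_ndim = 3
is_mxfp8 = False  # e.g. a delayed-scaling recipe

inspect_axes = norm_input_axes or dot_1_input_axes
should_warn = (
    inspect_axes is not None
    and len(inspect_axes) == x_ndim
    and inspect_axes[-1] is not None
    and not is_mxfp8
)
print(should_warn)  # True -> the UserWarning would be emitted
```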
@@ -352,9 +352,6 @@ class BlockScalingQuantizeConfig:
         cls.initialize(fp8_recipe)
         cls.AMAX_HISTORY_LEN = 0

-        # Use TE GEMM instead of JAX GEMM for better performance
-        tex.base.manage_primitives(enable_names=["GemmPrimitive"])
-
     @staticmethod
     def finalize() -> None:
         """Reset the block scaling configuration."""
......
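
With `GemmPrimitive` enabled by default, `BlockScalingQuantizeConfig` no longer needs the explicit opt-in removed above. For reference, a sketch of the programmatic toggle being deleted, assuming `tex` resolves to `transformer_engine.jax.cpp_extensions` as is conventional in the TE JAX code (only the `enable_names` keyword is confirmed by this diff):

```python
# Assumed alias; the diff itself only shows the `tex.base.` usage.
import transformer_engine.jax.cpp_extensions as tex

# Before this commit, block-scaling recipes had to re-enable the TE GEMM
# custom call explicitly during initialization:
tex.base.manage_primitives(enable_names=["GemmPrimitive"])
```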