Unverified Commit 6c579267 authored by Reese Wang's avatar Reese Wang Committed by GitHub
Browse files

[JAX] Allow enabling partial custom calls through the environment variable (#1007)



* Add enabled() to BasePrimitive

* Add layernorm/rmsnorm fallback

* Add cast_fp8 fallback

* Add transpose/cast_transpose XLA fallback

* Act_lu fallback

* Add transpose fallback

* Add softmax fallback

* Unify the use of _cast_fp8

* Add tests for NVTE_JAX_CUSTOM_CALLS_RE

---------
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 210e57de
...@@ -9,6 +9,9 @@ pip install pytest==8.2.1 ...@@ -9,6 +9,9 @@ pip install pytest==8.2.1
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed'
# Test without custom calls (empty pattern disables every TE custom call,
# forcing the JAX-native fallback paths).
# NOTE: the variable read by BasePrimitive.enabled() is NVTE_JAX_CUSTOM_CALLS_RE;
# the previous name NVTE_CUSTOM_CALLS_RE was ignored by the code.
NVTE_JAX_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py
pip install -r $TE_PATH/examples/jax/mnist/requirements.txt pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
......
...@@ -19,8 +19,10 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quanti ...@@ -19,8 +19,10 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quanti
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax import cpp_extensions as tex
GEMM_CASES = [ GEMM_CASES = [
(256, 256, 512), (256, 256, 512),
(32, 32, 32), (32, 32, 32),
...@@ -34,21 +36,6 @@ DTYPES = [jnp.bfloat16, jnp.float32] ...@@ -34,21 +36,6 @@ DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
if fn_or_string == "linear":
return lambda x: x
if fn_or_string == "quick_gelu":
return lambda x: nn.gelu(x, approximate=True)
if fn_or_string == "squared_relu":
return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)])
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
if callable(fn_or_string):
return fn_or_string
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
class TestFP8Dot: class TestFP8Dot:
@staticmethod @staticmethod
...@@ -293,14 +280,7 @@ class TestFP8Dot: ...@@ -293,14 +280,7 @@ class TestFP8Dot:
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape) linear_1_out += jnp.reshape(bias_1, bias_1_shape)
x = jnp.split(linear_1_out, len(activation_type), axis=-2) x = _jax_act_lu(linear_1_out, activation_type)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)
fp8_meta_pkg_2 = FP8MetaPackage( fp8_meta_pkg_2 = FP8MetaPackage(
amax_list_2[0], amax_list_2[0],
...@@ -443,12 +423,7 @@ class TestActivationLu: ...@@ -443,12 +423,7 @@ class TestActivationLu:
def ref_func(self, x, activation_type): def ref_func(self, x, activation_type):
def ref_act_lu(inputs): def ref_act_lu(inputs):
x = jnp.split(inputs, len(activation_type), axis=-2) x = _jax_act_lu(inputs, activation_type)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
return jnp.mean(x) return jnp.mean(x)
ref_act_func = jit(value_and_grad(ref_act_lu, (0,))) ref_act_func = jit(value_and_grad(ref_act_lu, (0,)))
......
...@@ -123,14 +123,12 @@ class SoftmaxRunner: ...@@ -123,14 +123,12 @@ class SoftmaxRunner:
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
jitted_primitive = jit( jitted_primitive = jit(
value_and_grad( value_and_grad(lambda logits, *args: grad_func(softmax, logits, *args, **kwargs), (0,))
lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), (0,)
)
) )
jitted_reference = jit( jitted_reference = jit(
value_and_grad( value_and_grad(
lambda logits, *args: grad_func( lambda logits, *args: grad_func(
__class__.reference_softmax, self.logits, *args, **kwargs __class__.reference_softmax, logits, *args, **kwargs
), ),
(0,), (0,),
) )
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
"""JAX/TE custom ops for activation""" """JAX/TE custom ops for activation"""
from typing import Tuple, Sequence, Union, Callable from typing import Tuple, Sequence, Union, Callable
import operator import operator
from functools import reduce from functools import reduce, partial
import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import core, dtypes from jax import core, dtypes
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
...@@ -22,6 +23,7 @@ from .misc import ( ...@@ -22,6 +23,7 @@ from .misc import (
jax_dtype_to_ir_dtype, jax_dtype_to_ir_dtype,
get_padded_spec, get_padded_spec,
) )
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP from ..sharding import all_reduce_max_along_all_axes_except_PP
...@@ -42,6 +44,35 @@ ActivationEnum = { ...@@ -42,6 +44,35 @@ ActivationEnum = {
} }
def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
if fn_or_string == "linear":
return lambda x: x
if fn_or_string == "quick_gelu":
return lambda x: jax.nn.sigmoid(1.702 * x) * x
if fn_or_string == "squared_relu":
return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
if isinstance(fn_or_string, str):
return getattr(jax.nn, fn_or_string)
if callable(fn_or_string):
return fn_or_string
raise ValueError(f"Unsupported {fn_or_string} to an activation function")
def _jax_act_lu(inputs, activation_type):
    """
    JAX native activation implementation.

    Splits ``inputs`` along axis -2 into one slice per entry of
    ``activation_type``, applies each activation to its slice, multiplies the
    results elementwise (gating), and squeezes the now-singleton -2 axis.
    """
    pieces = jnp.split(inputs, len(activation_type), axis=-2)
    activated = [
        _convert_to_activation_function(act)(piece)
        for act, piece in zip(activation_type, pieces)
    ]
    gated = reduce(operator.mul, activated)
    return jnp.squeeze(gated, axis=-2)
class ActLuPrimitive(BasePrimitive): class ActLuPrimitive(BasePrimitive):
""" """
Activation Forward Primitive Activation Forward Primitive
...@@ -155,6 +186,9 @@ def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) ...@@ -155,6 +186,9 @@ def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]])
Input shape: (N, 1, H) for non-gated activations Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations (N, 2, H) for gated activations
""" """
if not ActLuPrimitive.enabled():
return _jax_act_lu(inputs, activation_type)
act_type_id = ActivationEnum[activation_type] act_type_id = ActivationEnum[activation_type]
return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id) return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)
...@@ -286,6 +320,11 @@ def dact_lu( ...@@ -286,6 +320,11 @@ def dact_lu(
dact_lu fusion wrapper dact_lu fusion wrapper
Return dgated_act_lu(inputs) Return dgated_act_lu(inputs)
""" """
if not DActLuPrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs)
return vjp_func(inputs)[0]
act_type_id = ActivationEnum[activation_type] act_type_id = ActivationEnum[activation_type]
return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id) return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id)
...@@ -443,6 +482,11 @@ def act_lu_fp8( ...@@ -443,6 +482,11 @@ def act_lu_fp8(
Input shape: (N, 1, H) for non-gated activations Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations (N, 2, H) for gated activations
""" """
if not ActLuFp8Primitive.enabled():
act_lu_output = _jax_act_lu(x, activation_type)
casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype)
return casted_output, updated_amax
act_type_id = ActivationEnum[activation_type] act_type_id = ActivationEnum[activation_type]
return ActLuFp8Primitive.outer_primitive.bind( return ActLuFp8Primitive.outer_primitive.bind(
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE base custom ops""" """JAX/TE base custom ops"""
import os
import re
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from functools import partial from functools import partial
...@@ -17,6 +19,21 @@ class BasePrimitive(metaclass=ABCMeta): ...@@ -17,6 +19,21 @@ class BasePrimitive(metaclass=ABCMeta):
jax primitive jax primitive
""" """
# Custom-call name for this primitive; subclasses override it. It is the
# string matched against NVTE_JAX_CUSTOM_CALLS_RE in enabled().
name = None

@classmethod
def enabled(cls):
    """
    A custom call is marked as disabled if the `cls.name` does not fully match the
    `NVTE_JAX_CUSTOM_CALLS_RE` pattern.
    By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names.
    For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!te_act_lu$).+$'` to disable `te_act_lu`.
    """
    # Read the environment on every call so the setting can be changed at runtime.
    pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*")
    pattern = re.compile(pattern)
    # fullmatch: the entire name must match the pattern, not just a prefix.
    is_enabled = pattern.fullmatch(cls.name) is not None
    return is_enabled
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def abstract(): def abstract():
......
...@@ -7,6 +7,7 @@ import operator ...@@ -7,6 +7,7 @@ import operator
import os import os
import warnings import warnings
import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import core, dtypes from jax import core, dtypes
from jax.interpreters import mlir from jax.interpreters import mlir
...@@ -25,6 +26,7 @@ from .misc import ( ...@@ -25,6 +26,7 @@ from .misc import (
jax_dtype_to_ir_dtype, jax_dtype_to_ir_dtype,
te_dtype_to_jax_dtype, te_dtype_to_jax_dtype,
) )
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
...@@ -239,12 +241,77 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -239,12 +241,77 @@ class LayerNormFwdPrimitive(BasePrimitive):
register_primitive(LayerNormFwdPrimitive) register_primitive(LayerNormFwdPrimitive)
def _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps):
"""
JAX native layernorm implementation
"""
x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
if zero_centered_gamma:
gamma += 1.0
return jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
def _jax_rmsnorm(x, gamma, zero_centered_gamma, eps):
"""
JAX native rmsnorm implementation
"""
x_ = jnp.asarray(x, jnp.float32)
var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
normed_input = x_ * jax.lax.rsqrt(var + eps)
if zero_centered_gamma:
gamma += 1.0
return jnp.asarray(normed_input * gamma).astype(x.dtype)
def _jax_layernorm_fp8(x, gamma, beta, scale, amax, out_dtype, zero_centered_gamma, eps):
    """
    JAX native layernorm fp8 implementation.

    Returns ``(casted_output, mean, rsigma, updated_amax)``; mean/rsigma have
    the last axis squeezed to match the custom-call layout.
    """
    x32 = jnp.asarray(x, jnp.float32)
    mu = jnp.mean(x32, axis=-1, keepdims=True)
    sigma2 = jnp.mean(jnp.square(x32 - mu), axis=-1, keepdims=True)
    inv_sigma = jax.lax.rsqrt(sigma2 + eps)
    normed = (x32 - mu) * inv_sigma
    scale_param = gamma + 1.0 if zero_centered_gamma else gamma
    ln_out = normed * scale_param + beta
    q_out, new_amax = _jax_cast_fp8(ln_out, scale, amax, out_dtype=out_dtype)
    return q_out, jnp.squeeze(mu, axis=-1), jnp.squeeze(inv_sigma, axis=-1), new_amax
def _jax_rmsnorm_fp8(x, gamma, scale, amax, out_dtype, zero_centered_gamma, eps):
    """
    JAX native rmsnorm fp8 implementation.

    Returns ``(casted_output, rsigma, updated_amax)``; rsigma has the last
    axis squeezed to match the custom-call layout.
    """
    x32 = jnp.asarray(x, jnp.float32)
    mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    inv_rms = jax.lax.rsqrt(mean_sq + eps)
    normed = x32 * inv_rms
    scale_param = gamma + 1.0 if zero_centered_gamma else gamma
    rms_out = normed * scale_param
    q_out, new_amax = _jax_cast_fp8(rms_out, scale, amax, out_dtype=out_dtype)
    return q_out, jnp.squeeze(inv_rms, axis=-1), new_amax
def layernorm_fwd( def layernorm_fwd(
x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float
): ):
""" """
Wrapper for TE layernorm fwd Wrapper for TE layernorm fwd
""" """
if not LayerNormFwdPrimitive.enabled():
x_ = jnp.asarray(x, jnp.float32)
mu = jnp.mean(x_, axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_ - mu), axis=-1, keepdims=True) + epsilon)
return (
_jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon),
jnp.squeeze(mu, axis=-1),
jnp.squeeze(rsigma, axis=-1),
)
return LayerNormFwdPrimitive.outer_primitive.bind( return LayerNormFwdPrimitive.outer_primitive.bind(
x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
) )
...@@ -468,12 +535,21 @@ def layernorm_bwd( ...@@ -468,12 +535,21 @@ def layernorm_bwd(
mu: jnp.ndarray, mu: jnp.ndarray,
rsigma: jnp.ndarray, rsigma: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
beta: jnp.ndarray,
zero_centered_gamma: bool, zero_centered_gamma: bool,
epsilon: float, epsilon: float,
): ):
""" """
Wrapper for TE layernorm bwd Wrapper for TE layernorm bwd
""" """
if not LayerNormBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_layernorm, zero_centered_gamma=zero_centered_gamma, eps=epsilon),
x,
gamma,
beta,
)
return vjp_func(dz)
return LayerNormBwdPrimitive.outer_primitive.bind( return LayerNormBwdPrimitive.outer_primitive.bind(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
) )
...@@ -655,6 +731,12 @@ def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float): ...@@ -655,6 +731,12 @@ def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float):
""" """
Wrapper for TE rmsnorm fwd Wrapper for TE rmsnorm fwd
""" """
if not RmsNormFwdPrimitive.enabled():
x_ = jnp.asarray(x, jnp.float32)
rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + epsilon)
return _jax_rmsnorm(x, gamma, zero_centered_gamma=False, eps=epsilon), jnp.squeeze(
rsigma, axis=-1
)
return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon) return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon)
...@@ -852,6 +934,11 @@ def rmsnorm_bwd( ...@@ -852,6 +934,11 @@ def rmsnorm_bwd(
""" """
Wrapper for TE layernorm bwd Wrapper for TE layernorm bwd
""" """
if not RmsNormBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_rmsnorm, zero_centered_gamma=False, eps=epsilon), x, gamma
)
return vjp_func(dz)
return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon) return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
...@@ -1148,6 +1235,17 @@ def layernorm_fwd_fp8( ...@@ -1148,6 +1235,17 @@ def layernorm_fwd_fp8(
""" """
Wrapper for TE layernorm fwd (fp8 out) Wrapper for TE layernorm fwd (fp8 out)
""" """
if not LayerNormFwdFp8Primitive.enabled():
return _jax_layernorm_fp8(
x,
gamma,
beta,
scale,
amax,
out_dtype=out_dtype,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
)
return LayerNormFwdFp8Primitive.outer_primitive.bind( return LayerNormFwdFp8Primitive.outer_primitive.bind(
x, x,
gamma, gamma,
...@@ -1387,6 +1485,10 @@ def rmsnorm_fwd_fp8( ...@@ -1387,6 +1485,10 @@ def rmsnorm_fwd_fp8(
""" """
Wrapper for TE rmsnorm fwd (fp8 out) Wrapper for TE rmsnorm fwd (fp8 out)
""" """
if not RmsNormFwdFp8Primitive.enabled():
return _jax_rmsnorm_fp8(
x, gamma, scale, amax, out_dtype=out_dtype, zero_centered_gamma=False, eps=epsilon
)
return RmsNormFwdFp8Primitive.outer_primitive.bind( return RmsNormFwdFp8Primitive.outer_primitive.bind(
x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
) )
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
"""JAX/TE custom ops for quantization""" """JAX/TE custom ops for quantization"""
from typing import Tuple from typing import Tuple
import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
...@@ -26,6 +27,26 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP ...@@ -26,6 +27,26 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP
__all__ = ["cast_fp8"] __all__ = ["cast_fp8"]
def _jax_quantize(x, scale, q_dtype):
"""
Quantize with scale
"""
compute_dtype = scale.dtype
dtype_max = (jnp.finfo(q_dtype).max).astype(compute_dtype)
scaled_x = x.astype(compute_dtype) * scale
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
return clipped_scaled_x.astype(q_dtype)
def _jax_cast_fp8(inputs, scale, amax, out_dtype):
    """
    JAX native fp8 casting implementation.

    Returns the quantized tensor and the running amax updated with the
    absolute max of ``inputs``.
    """
    new_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype))
    quantized = _jax_quantize(inputs, scale, q_dtype=out_dtype)
    return quantized, new_amax
class CastFP8Primitive(BasePrimitive): class CastFP8Primitive(BasePrimitive):
""" """
Cast Primitive Cast Primitive
...@@ -157,4 +178,6 @@ def cast_fp8( ...@@ -157,4 +178,6 @@ def cast_fp8(
Cast wrapper Cast wrapper
Return FP8 tensor Return FP8 tensor
""" """
if not CastFP8Primitive.enabled():
return _jax_cast_fp8(x, scale, amax, out_dtype=out_dtype)
return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype) return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
...@@ -7,6 +7,7 @@ from functools import partial, reduce ...@@ -7,6 +7,7 @@ from functools import partial, reduce
import operator import operator
import warnings import warnings
import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import core, dtypes from jax import core, dtypes
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
...@@ -31,6 +32,30 @@ __all__ = [ ...@@ -31,6 +32,30 @@ __all__ = [
] ]
def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
return jax.nn.softmax(scale_factor * logits)
def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
if mask is not None:
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
mask = 1 - jnp.tril(jnp.ones_like(logits))
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
def is_softmax_kernel_available( def is_softmax_kernel_available(
softmax_type: SoftmaxType, softmax_type: SoftmaxType,
batch: int, batch: int,
...@@ -395,6 +420,8 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: ...@@ -395,6 +420,8 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
scaled_softmax_forward wrapper scaled_softmax_forward wrapper
Return FP16/BF16 tensor Return FP16/BF16 tensor
""" """
if not ScaledSoftmaxFwdPrimitive.enabled():
return _jax_scaled_softmax(logits, scale_factor)
return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor) return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
...@@ -469,12 +496,16 @@ register_primitive(ScaledSoftmaxBwdPrimitive) ...@@ -469,12 +496,16 @@ register_primitive(ScaledSoftmaxBwdPrimitive)
def scaled_softmax_bwd( def scaled_softmax_bwd(
dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray: ) -> jnp.ndarray:
""" """
scaled_backward wrapper scaled_backward wrapper
Return FP16/BF16 tensor Return FP16/BF16 tensor
""" """
if not ScaledSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits)
return vjp_func(dz)[0]
return ScaledSoftmaxBwdPrimitive.outer_primitive.bind( return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(
dz, softmax_out, scale_factor=scale_factor dz, softmax_out, scale_factor=scale_factor
) )
...@@ -625,6 +656,8 @@ def scaled_masked_softmax_fwd( ...@@ -625,6 +656,8 @@ def scaled_masked_softmax_fwd(
scaled_masked_softmax_forward wrapper scaled_masked_softmax_forward wrapper
Return FP16/BF16 tensor Return FP16/BF16 tensor
""" """
if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_masked_softmax(logits, mask, scale_factor)
return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind( return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, mask, scale_factor=scale_factor logits, mask, scale_factor=scale_factor
) )
...@@ -704,12 +737,21 @@ register_primitive(ScaledMaskedSoftmaxBwdPrimitive) ...@@ -704,12 +737,21 @@ register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
def scaled_masked_softmax_bwd( def scaled_masked_softmax_bwd(
dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float dz: jnp.ndarray,
softmax_out: jnp.ndarray,
logits: jnp.ndarray,
mask: jnp.ndarray,
scale_factor: float,
) -> jnp.ndarray: ) -> jnp.ndarray:
""" """
scaled_masked_backward wrapper scaled_masked_backward wrapper
Return FP16/BF16 tensor Return FP16/BF16 tensor
""" """
if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
)
return vjp_func(dz)[0]
return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind( return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
dz, softmax_out, scale_factor=scale_factor dz, softmax_out, scale_factor=scale_factor
) )
...@@ -806,6 +848,8 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl ...@@ -806,6 +848,8 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl
scaled_upper_triang_masked_softmax_forward wrapper scaled_upper_triang_masked_softmax_forward wrapper
Return FP16/BF16 tensor Return FP16/BF16 tensor
""" """
if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind( return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, scale_factor=scale_factor logits, scale_factor=scale_factor
) )
...@@ -893,12 +937,17 @@ register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) ...@@ -893,12 +937,17 @@ register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
def scaled_upper_triang_masked_softmax_bwd( def scaled_upper_triang_masked_softmax_bwd(
dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray: ) -> jnp.ndarray:
""" """
scaled_upper_triang_masked_backward wrapper scaled_upper_triang_masked_backward wrapper
Return FP16/BF16 tensor Return FP16/BF16 tensor
""" """
if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
)
return vjp_func(dz)[0]
return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind( return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
dz, softmax_out, scale_factor=scale_factor dz, softmax_out, scale_factor=scale_factor
) )
...@@ -6,6 +6,7 @@ from functools import partial, reduce ...@@ -6,6 +6,7 @@ from functools import partial, reduce
from typing import Tuple, Sequence, Union, Callable from typing import Tuple, Sequence, Union, Callable
import operator import operator
import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
...@@ -26,6 +27,8 @@ from .misc import ( ...@@ -26,6 +27,8 @@ from .misc import (
normalize_axis_boundary, normalize_axis_boundary,
) )
from .activation import ActivationEnum from .activation import ActivationEnum
from .activation import _jax_act_lu
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
...@@ -38,6 +41,27 @@ __all__ = [ ...@@ -38,6 +41,27 @@ __all__ = [
] ]
def _jax_transpose(inputs, static_axis_boundary, transpose_axis_boundary):
    """
    JAX native transpose implementation
    """
    # The axis permutation is computed by the project helper `multidim_transpose`
    # from the two boundary arguments; presumably axes before
    # `static_axis_boundary` stay in place while the remaining groups swap —
    # exact semantics defined in .misc (TODO confirm there).
    axes = multidim_transpose(range(inputs.ndim), static_axis_boundary, transpose_axis_boundary)
    return jnp.transpose(inputs, axes=axes)
def _jax_cast_transpose(
    inputs, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
):
    """
    JAX native cast_transpose implementation.

    Quantizes ``inputs`` to ``out_dtype`` and also produces the transposed
    quantized tensor; returns ``(casted, casted_transposed, updated_amax)``.
    """
    quantized, new_amax = _jax_cast_fp8(inputs, scale, amax, out_dtype=out_dtype)
    quantized_t = _jax_transpose(quantized, static_axis_boundary, transpose_axis_boundary)
    return quantized, quantized_t, new_amax
class TransposePrimitive(BasePrimitive): class TransposePrimitive(BasePrimitive):
""" """
Transpose Primitive Transpose Primitive
...@@ -176,6 +200,8 @@ def transpose( ...@@ -176,6 +200,8 @@ def transpose(
""" """
transpose wrapper transpose wrapper
""" """
if not TransposePrimitive.enabled():
return _jax_transpose(x, static_axis_boundary, transpose_axis_boundary)
return TransposePrimitive.outer_primitive.bind( return TransposePrimitive.outer_primitive.bind(
x, x,
static_axis_boundary=static_axis_boundary, static_axis_boundary=static_axis_boundary,
...@@ -381,6 +407,15 @@ def cast_transpose( ...@@ -381,6 +407,15 @@ def cast_transpose(
cast transpose wrapper cast transpose wrapper
Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale` Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale`
""" """
if not CastTransposePrimitive.enabled():
return _jax_cast_transpose(
x,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
return CastTransposePrimitive.outer_primitive.bind( return CastTransposePrimitive.outer_primitive.bind(
x, x,
amax, amax,
...@@ -631,6 +666,28 @@ def dbias_cast_transpose( ...@@ -631,6 +666,28 @@ def dbias_cast_transpose(
if static_axis_boundary < 0: if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes static_axis_boundary = -1 # means no static axes
if not DBiasCastTransposePrimitive.enabled():
casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose(
dz,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
dbias = jnp.sum(
dz,
axis=tuple(
range(
transpose_axis_boundary
if transpose_axis_boundary > 0
else transpose_axis_boundary + dz.ndim
)
),
keepdims=False,
)
return casted_dz, cast_transposed_dz, dbias, updated_amax
return DBiasCastTransposePrimitive.outer_primitive.bind( return DBiasCastTransposePrimitive.outer_primitive.bind(
dz, dz,
amax, amax,
...@@ -947,6 +1004,31 @@ def dact_lu_dbias_cast_transpose( ...@@ -947,6 +1004,31 @@ def dact_lu_dbias_cast_transpose(
if static_axis_boundary < 0: if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes static_axis_boundary = -1 # means no static axes
if not DActLuDBiasCastTransposePrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
(dx,) = vjp_func(dz)
casted_dx, cast_transposed_dx, updated_amax = _jax_cast_transpose(
dx,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
dbias = jnp.squeeze(
jnp.sum(
dx,
axis=tuple(
range(
transpose_axis_boundary
if transpose_axis_boundary > 0
else transpose_axis_boundary + dx.ndim
)
),
)
)
return casted_dx, cast_transposed_dx, dbias, updated_amax
act_type_id = ActivationEnum[activation_type] act_type_id = ActivationEnum[activation_type]
return DActLuDBiasCastTransposePrimitive.outer_primitive.bind( return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
dz, dz,
...@@ -1161,6 +1243,17 @@ def dgated_act_lu_cast_transpose( ...@@ -1161,6 +1243,17 @@ def dgated_act_lu_cast_transpose(
Return FP8(dgated_act_lu(inputs)) Return FP8(dgated_act_lu(inputs))
""" """
act_type_id = ActivationEnum[activation_type] act_type_id = ActivationEnum[activation_type]
if not DgatedActLuCastTransposePrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
(dx,) = vjp_func(dz)
return _jax_cast_transpose(
dx,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=-2,
)
return DgatedActLuCastTransposePrimitive.outer_primitive.bind( return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
dz, dz,
x, x,
......
...@@ -69,14 +69,14 @@ def _layernorm_fwd_rule( ...@@ -69,14 +69,14 @@ def _layernorm_fwd_rule(
mu = None mu = None
else: else:
raise ValueError(f"{layernorm_type=} is not supported.") raise ValueError(f"{layernorm_type=} is not supported.")
return output, (x, mu, rsigma, gamma) return output, (x, mu, rsigma, gamma, beta)
def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz): def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
x, mu, rsigma, gamma = ctx x, mu, rsigma, gamma, beta = ctx
if layernorm_type == "layernorm": if layernorm_type == "layernorm":
dx, dgamma, dbeta = tex.layernorm_bwd( dx, dgamma, dbeta = tex.layernorm_bwd(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon dz, x, mu, rsigma, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
) )
elif layernorm_type == "rmsnorm": elif layernorm_type == "rmsnorm":
assert ( assert (
...@@ -267,6 +267,7 @@ def _layernorm_fp8_dot_fwd_rule( ...@@ -267,6 +267,7 @@ def _layernorm_fp8_dot_fwd_rule(
rsigma, rsigma,
x, x,
gamma, gamma,
beta,
x_contracting_dims, x_contracting_dims,
k_contracting_dims, k_contracting_dims,
maybe_fp32_to_fm32, maybe_fp32_to_fm32,
...@@ -300,6 +301,7 @@ def _layernorm_fp8_dot_bwd_rule( ...@@ -300,6 +301,7 @@ def _layernorm_fp8_dot_bwd_rule(
rsigma, rsigma,
x, x,
gamma, gamma,
beta,
x_contracting_dims, x_contracting_dims,
k_contracting_dims, k_contracting_dims,
maybe_fp32_to_fm32, maybe_fp32_to_fm32,
...@@ -352,7 +354,14 @@ def _layernorm_fp8_dot_bwd_rule( ...@@ -352,7 +354,14 @@ def _layernorm_fp8_dot_bwd_rule(
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
if layernorm_type == "layernorm": if layernorm_type == "layernorm":
dx, dgamma, dbeta = tex.layernorm_bwd( dx, dgamma, dbeta = tex.layernorm_bwd(
dgrad, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon dgrad,
x,
mu,
rsigma,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
) )
else: else:
assert ( assert (
......
...@@ -344,6 +344,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule( ...@@ -344,6 +344,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
mu, mu,
rsigma, rsigma,
gamma, gamma,
beta,
dot_1_output, dot_1_output,
casted_activation_lu_out, casted_activation_lu_out,
casted_kernel_1, casted_kernel_1,
...@@ -390,6 +391,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -390,6 +391,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
mu, mu,
rsigma, rsigma,
gamma, gamma,
beta,
dot_1_output, dot_1_output,
casted_activation_lu_out, casted_activation_lu_out,
casted_kernel_1, casted_kernel_1,
...@@ -568,7 +570,14 @@ def _fused_layernorm_fp8_mlp_bwd_rule( ...@@ -568,7 +570,14 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
if layernorm_type == "layernorm": if layernorm_type == "layernorm":
dx, dgamma, dbeta = tex.layernorm_bwd( dx, dgamma, dbeta = tex.layernorm_bwd(
dgrad_1, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon dgrad_1,
x,
mu,
rsigma,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
) )
else: else:
assert ( assert (
......
...@@ -49,18 +49,18 @@ def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type): ...@@ -49,18 +49,18 @@ def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type):
else: else:
output = tex.scaled_softmax_fwd(logits, scale_factor) output = tex.scaled_softmax_fwd(logits, scale_factor)
return output, (output,) return output, (output, logits, mask)
def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz): def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz):
(softmax_output,) = ctx (softmax_output, logits, mask) = ctx
if softmax_type is SoftmaxType.SCALED_MASKED: if softmax_type is SoftmaxType.SCALED_MASKED:
dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, scale_factor) dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, logits, mask, scale_factor)
elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED: elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, scale_factor) dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, logits, scale_factor)
else: else:
dgrad = tex.scaled_softmax_bwd(dz, softmax_output, scale_factor) dgrad = tex.scaled_softmax_bwd(dz, softmax_output, logits, scale_factor)
return (dgrad, None) return (dgrad, None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment