Unverified Commit 6c579267 authored by Reese Wang, committed by GitHub

[JAX] Allow enabling partial custom calls through the environment variable (#1007)



* Add enabled() to BasePrimitive

* Add layernorm/rmsnorm fallback

* Add cast_fp8 fallback

* Add transpose/cast_transpose XLA fall back

* Act_lu fallback

* Add transpose fallback

* Add softmax fallback

* Unify the use of _cast_fp8

* Add tests for NVTE_JAX_CUSTOM_CALLS_RE

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 210e57de
......@@ -9,6 +9,9 @@ pip install pytest==8.2.1
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed'
# Test without custom calls
NVTE_JAX_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py
pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
......
......@@ -19,8 +19,10 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quanti
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
from transformer_engine.jax import cpp_extensions as tex
GEMM_CASES = [
(256, 256, 512),
(32, 32, 32),
......@@ -34,21 +36,6 @@ DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, reason = is_fp8_available()
def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
if fn_or_string == "linear":
return lambda x: x
if fn_or_string == "quick_gelu":
return lambda x: nn.gelu(x, approximate=True)
if fn_or_string == "squared_relu":
return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)])
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
if callable(fn_or_string):
return fn_or_string
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
class TestFP8Dot:
@staticmethod
......@@ -293,14 +280,7 @@ class TestFP8Dot:
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
x = jnp.split(linear_1_out, len(activation_type), axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)
x = _jax_act_lu(linear_1_out, activation_type)
fp8_meta_pkg_2 = FP8MetaPackage(
amax_list_2[0],
......@@ -443,12 +423,7 @@ class TestActivationLu:
def ref_func(self, x, activation_type):
def ref_act_lu(inputs):
x = jnp.split(inputs, len(activation_type), axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
x = _jax_act_lu(inputs, activation_type)
return jnp.mean(x)
ref_act_func = jit(value_and_grad(ref_act_lu, (0,)))
......
......@@ -123,14 +123,12 @@ class SoftmaxRunner:
# Summing the results in FP16/BF16 may cause overflow, so use FP32 for the summation
jitted_primitive = jit(
value_and_grad(
lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), (0,)
)
value_and_grad(lambda logits, *args: grad_func(softmax, logits, *args, **kwargs), (0,))
)
jitted_reference = jit(
value_and_grad(
lambda logits, *args: grad_func(
__class__.reference_softmax, self.logits, *args, **kwargs
__class__.reference_softmax, logits, *args, **kwargs
),
(0,),
)
......
......@@ -4,8 +4,9 @@
"""JAX/TE custom ops for activation"""
from typing import Tuple, Sequence, Union, Callable
import operator
from functools import reduce
from functools import reduce, partial
import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters.mlir import ir
......@@ -22,6 +23,7 @@ from .misc import (
jax_dtype_to_ir_dtype,
get_padded_spec,
)
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP
......@@ -42,6 +44,35 @@ ActivationEnum = {
}
def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
if fn_or_string == "linear":
return lambda x: x
if fn_or_string == "quick_gelu":
return lambda x: jax.nn.sigmoid(1.702 * x) * x
if fn_or_string == "squared_relu":
return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
if isinstance(fn_or_string, str):
return getattr(jax.nn, fn_or_string)
if callable(fn_or_string):
return fn_or_string
raise ValueError(f"Unsupported {fn_or_string} to an activation function")
def _jax_act_lu(inputs, activation_type):
"""
JAX native activation implementation
"""
x = jnp.split(inputs, len(activation_type), axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = reduce(operator.mul, acts)
x = jnp.squeeze(x, axis=-2)
return x
class ActLuPrimitive(BasePrimitive):
"""
Activation Forward Primitive
......@@ -155,6 +186,9 @@ def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]])
Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations
"""
if not ActLuPrimitive.enabled():
return _jax_act_lu(inputs, activation_type)
act_type_id = ActivationEnum[activation_type]
return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)
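To make the shape convention above concrete, here is a minimal sketch of the native fallback on a gated activation (illustrative only; the tensor sizes are made up):

import jax.numpy as jnp

# Hypothetical gated-GeLU input: N = 4 tokens, 2 branches (gate and linear), H = 8.
inputs = jnp.ones((4, 2, 8), jnp.bfloat16)
out = _jax_act_lu(inputs, ("gelu", "linear"))
# The two (4, 1, 8) branches are activated, multiplied together, and axis -2 is
# squeezed away, so `out` has shape (4, 8).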
......@@ -286,6 +320,11 @@ def dact_lu(
dact_lu fusion wrapper
Return dgated_act_lu(inputs)
"""
if not DActLuPrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs)
return vjp_func(inputs)[0]
act_type_id = ActivationEnum[activation_type]
return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id)
......@@ -443,6 +482,11 @@ def act_lu_fp8(
Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations
"""
if not ActLuFp8Primitive.enabled():
act_lu_output = _jax_act_lu(x, activation_type)
casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype)
return casted_output, updated_amax
act_type_id = ActivationEnum[activation_type]
return ActLuFp8Primitive.outer_primitive.bind(
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id
......
......@@ -2,6 +2,8 @@
#
# See LICENSE for license information.
"""JAX/TE base custom ops"""
import os
import re
from abc import ABCMeta, abstractmethod
from functools import partial
......@@ -17,6 +19,21 @@ class BasePrimitive(metaclass=ABCMeta):
jax primitive
"""
name = None
@classmethod
def enabled(cls):
"""
A custom call is marked as disabled if the `cls.name` does not fully match the
`NVTE_JAX_CUSTOM_CALLS_RE` pattern.
By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names.
For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!te_act_lu$).+$'` to disable `te_act_lu`.
"""
pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*")
pattern = re.compile(pattern)
is_enabled = pattern.fullmatch(cls.name) is not None
return is_enabled
@staticmethod
@abstractmethod
def abstract():
......
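To illustrate the `enabled()` gate above, here is a standalone sketch of how the regular expression selects which custom calls stay active (`te_act_lu` comes from the docstring; `te_some_other_call` is a hypothetical name used only for illustration):

import os
import re

# Disable only te_act_lu; every other custom call keeps using the TE kernel.
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = r"^(?!te_act_lu$).+$"
pattern = re.compile(os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*"))
print(pattern.fullmatch("te_act_lu") is not None)           # False -> native JAX fallback
print(pattern.fullmatch("te_some_other_call") is not None)  # True  -> TE custom call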
......@@ -7,6 +7,7 @@ import operator
import os
import warnings
import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters import mlir
......@@ -25,6 +26,7 @@ from .misc import (
jax_dtype_to_ir_dtype,
te_dtype_to_jax_dtype,
)
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
......@@ -239,12 +241,77 @@ class LayerNormFwdPrimitive(BasePrimitive):
register_primitive(LayerNormFwdPrimitive)
def _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps):
"""
JAX native layernorm implementation
"""
x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
if zero_centered_gamma:
gamma += 1.0
return jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
def _jax_rmsnorm(x, gamma, zero_centered_gamma, eps):
"""
JAX native rmsnorm implementation
"""
x_ = jnp.asarray(x, jnp.float32)
var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
normed_input = x_ * jax.lax.rsqrt(var + eps)
if zero_centered_gamma:
gamma += 1.0
return jnp.asarray(normed_input * gamma).astype(x.dtype)
def _jax_layernorm_fp8(x, gamma, beta, scale, amax, out_dtype, zero_centered_gamma, eps):
"""
JAX native layernorm fp8 implementation
"""
x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + eps)
normed_input = (x_ - mean) * rsigma
if zero_centered_gamma:
gamma += 1.0
output = normed_input * gamma + beta
casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype)
return casted_output, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1), updated_amax
def _jax_rmsnorm_fp8(x, gamma, scale, amax, out_dtype, zero_centered_gamma, eps):
"""
JAX native rmsnorm fp8 implementation
"""
x_ = jnp.asarray(x, jnp.float32)
var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + eps)
normed_input = x_ * rsigma
if zero_centered_gamma:
gamma += 1.0
output = normed_input * gamma
casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype)
return casted_output, jnp.squeeze(rsigma, axis=-1), updated_amax
def layernorm_fwd(
x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float
):
"""
Wrapper for TE layernorm fwd
"""
if not LayerNormFwdPrimitive.enabled():
x_ = jnp.asarray(x, jnp.float32)
mu = jnp.mean(x_, axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_ - mu), axis=-1, keepdims=True) + epsilon)
return (
_jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon),
jnp.squeeze(mu, axis=-1),
jnp.squeeze(rsigma, axis=-1),
)
return LayerNormFwdPrimitive.outer_primitive.bind(
x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
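As a side note on `zero_centered_gamma` in the `_jax_layernorm` fallback above: gamma is stored as an offset from one, so a zero-initialized gamma behaves like a unit scale. A small sketch with toy values:

import jax.numpy as jnp

x = jnp.arange(8.0, dtype=jnp.float32).reshape(2, 4)
beta = jnp.zeros((4,), jnp.float32)
# gamma stored near zero with zero_centered_gamma=True ...
a = _jax_layernorm(x, jnp.zeros((4,)), beta, zero_centered_gamma=True, eps=1e-6)
# ... matches a plain layernorm with unit gamma.
b = _jax_layernorm(x, jnp.ones((4,)), beta, zero_centered_gamma=False, eps=1e-6)
# a and b are numerically identical.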
......@@ -468,12 +535,21 @@ def layernorm_bwd(
mu: jnp.ndarray,
rsigma: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
zero_centered_gamma: bool,
epsilon: float,
):
"""
Wrapper for TE layernorm bwd
"""
if not LayerNormBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_layernorm, zero_centered_gamma=zero_centered_gamma, eps=epsilon),
x,
gamma,
beta,
)
return vjp_func(dz)
return LayerNormBwdPrimitive.outer_primitive.bind(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
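The backward fallbacks in this change all follow the same pattern: instead of a hand-written native gradient, the wrapper rebuilds one from the forward fallback with `jax.vjp`. A self-contained sketch of that pattern with a toy forward function (not a TE API):

import jax
import jax.numpy as jnp
from functools import partial

def toy_forward(x, gamma, eps):
    # Stand-in for a native forward fallback such as _jax_layernorm or _jax_rmsnorm.
    return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps) * gamma

x = jnp.arange(8.0, dtype=jnp.float32).reshape(2, 4)
gamma = jnp.ones((4,), jnp.float32)
dz = jnp.ones_like(x)                              # incoming cotangent (dL/d(output))
_, vjp_func = jax.vjp(partial(toy_forward, eps=1e-6), x, gamma)
dx, dgamma = vjp_func(dz)                          # one gradient per differentiated argument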
......@@ -655,6 +731,12 @@ def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float):
"""
Wrapper for TE rmsnorm fwd
"""
if not RmsNormFwdPrimitive.enabled():
x_ = jnp.asarray(x, jnp.float32)
rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + epsilon)
return _jax_rmsnorm(x, gamma, zero_centered_gamma=False, eps=epsilon), jnp.squeeze(
rsigma, axis=-1
)
return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon)
......@@ -852,6 +934,11 @@ def rmsnorm_bwd(
"""
Wrapper for TE rmsnorm bwd
"""
if not RmsNormBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_rmsnorm, zero_centered_gamma=False, eps=epsilon), x, gamma
)
return vjp_func(dz)
return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
......@@ -1148,6 +1235,17 @@ def layernorm_fwd_fp8(
"""
Wrapper for TE layernorm fwd (fp8 out)
"""
if not LayerNormFwdFp8Primitive.enabled():
return _jax_layernorm_fp8(
x,
gamma,
beta,
scale,
amax,
out_dtype=out_dtype,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
)
return LayerNormFwdFp8Primitive.outer_primitive.bind(
x,
gamma,
......@@ -1387,6 +1485,10 @@ def rmsnorm_fwd_fp8(
"""
Wrapper for TE rmsnorm fwd (fp8 out)
"""
if not RmsNormFwdFp8Primitive.enabled():
return _jax_rmsnorm_fp8(
x, gamma, scale, amax, out_dtype=out_dtype, zero_centered_gamma=False, eps=epsilon
)
return RmsNormFwdFp8Primitive.outer_primitive.bind(
x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
)
......@@ -4,6 +4,7 @@
"""JAX/TE custom ops for quantization"""
from typing import Tuple
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
......@@ -26,6 +27,26 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP
__all__ = ["cast_fp8"]
def _jax_quantize(x, scale, q_dtype):
"""
Quantize with scale
"""
compute_dtype = scale.dtype
dtype_max = (jnp.finfo(q_dtype).max).astype(compute_dtype)
scaled_x = x.astype(compute_dtype) * scale
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
return clipped_scaled_x.astype(q_dtype)
def _jax_cast_fp8(inputs, scale, amax, out_dtype):
"""
JAX native fp8 casting implementation
"""
casted_output = _jax_quantize(inputs, scale, q_dtype=out_dtype)
updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype))
return casted_output, updated_amax
class CastFP8Primitive(BasePrimitive):
"""
Cast Primitive
......@@ -157,4 +178,6 @@ def cast_fp8(
Cast wrapper
Return FP8 tensor
"""
if not CastFP8Primitive.enabled():
return _jax_cast_fp8(x, scale, amax, out_dtype=out_dtype)
return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
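A quick sketch of the native cast path above (assuming a JAX build that exposes `jnp.float8_e4m3fn`; the values are made up):

import jax.numpy as jnp

x = jnp.array([[0.5, -3.0], [2.0, 0.25]], jnp.float32)
scale = jnp.array(1.0, jnp.float32)     # current FP8 scaling factor
amax = jnp.array(0.0, jnp.float32)      # running amax to be updated
y, new_amax = _jax_cast_fp8(x, scale, amax, out_dtype=jnp.float8_e4m3fn)
# y holds the scaled values clipped to the FP8 range; new_amax becomes max(|x|) = 3.0.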
......@@ -7,6 +7,7 @@ from functools import partial, reduce
import operator
import warnings
import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters.mlir import ir
......@@ -31,6 +32,30 @@ __all__ = [
]
def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
return jax.nn.softmax(scale_factor * logits)
def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
if mask is not None:
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
mask = 1 - jnp.tril(jnp.ones_like(logits))
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
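A tiny worked example of the upper-triangular (causal) fallback above, showing that each query position only attends to keys at or before it (toy logits, probabilities rounded):

import jax.numpy as jnp

logits = jnp.zeros((1, 3, 3), jnp.float32)   # (..., queries, keys)
probs = _jax_scaled_upper_triang_masked_softmax(logits, scale_factor=1.0)
# probs[0] is approximately:
# [[1.00, 0.00, 0.00],
#  [0.50, 0.50, 0.00],
#  [0.33, 0.33, 0.33]]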
def is_softmax_kernel_available(
softmax_type: SoftmaxType,
batch: int,
......@@ -395,6 +420,8 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
scaled_softmax_forward wrapper
Return FP16/BF16 tensor
"""
if not ScaledSoftmaxFwdPrimitive.enabled():
return _jax_scaled_softmax(logits, scale_factor)
return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
......@@ -469,12 +496,16 @@ register_primitive(ScaledSoftmaxBwdPrimitive)
def scaled_softmax_bwd(
dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float
dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
"""
scaled_backward wrapper
Return FP16/BF16 tensor
"""
if not ScaledSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits)
return vjp_func(dz)[0]
return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(
dz, softmax_out, scale_factor=scale_factor
)
......@@ -625,6 +656,8 @@ def scaled_masked_softmax_fwd(
scaled_masked_softmax_forward wrapper
Return FP16/BF16 tensor
"""
if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_masked_softmax(logits, mask, scale_factor)
return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, mask, scale_factor=scale_factor
)
......@@ -704,12 +737,21 @@ register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
def scaled_masked_softmax_bwd(
dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float
dz: jnp.ndarray,
softmax_out: jnp.ndarray,
logits: jnp.ndarray,
mask: jnp.ndarray,
scale_factor: float,
) -> jnp.ndarray:
"""
scaled_masked_backward wrapper
Return FP16/BF16 tensor
"""
if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
)
return vjp_func(dz)[0]
return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
dz, softmax_out, scale_factor=scale_factor
)
......@@ -806,6 +848,8 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl
scaled_upper_triang_masked_softmax_forward wrapper
Return FP16/BF16 tensor
"""
if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, scale_factor=scale_factor
)
......@@ -893,12 +937,17 @@ register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
def scaled_upper_triang_masked_softmax_bwd(
dz: jnp.ndarray, softmax_out: jnp.ndarray, scale_factor: float
dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
"""
scaled_upper_triang_masked_backward wrapper
Return FP16/BF16 tensor
"""
if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
)
return vjp_func(dz)[0]
return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
dz, softmax_out, scale_factor=scale_factor
)
......@@ -6,6 +6,7 @@ from functools import partial, reduce
from typing import Tuple, Sequence, Union, Callable
import operator
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
......@@ -26,6 +27,8 @@ from .misc import (
normalize_axis_boundary,
)
from .activation import ActivationEnum
from .activation import _jax_act_lu
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
......@@ -38,6 +41,27 @@ __all__ = [
]
def _jax_transpose(inputs, static_axis_boundary, transpose_axis_boundary):
"""
JAX native transpose implementation
"""
axes = multidim_transpose(range(inputs.ndim), static_axis_boundary, transpose_axis_boundary)
return jnp.transpose(inputs, axes=axes)
def _jax_cast_transpose(
inputs, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
):
"""
JAX native cast_transpose implementation
"""
casted_output, updated_amax = _jax_cast_fp8(inputs, scale, amax, out_dtype=out_dtype)
casted_transposed_output = _jax_transpose(
casted_output, static_axis_boundary, transpose_axis_boundary
)
return casted_output, casted_transposed_output, updated_amax
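For the plain 2-D case the fallback above amounts to an FP8 cast plus a matrix transpose; a hedged sketch (assuming `static_axis_boundary=-1` means "no static axes", as noted in the wrappers below, and that `jnp.float8_e4m3fn` is available):

import jax.numpy as jnp

x = jnp.arange(6.0, dtype=jnp.float32).reshape(2, 3)
scale = jnp.array(1.0, jnp.float32)
amax = jnp.array(0.0, jnp.float32)
y, y_t, new_amax = _jax_cast_transpose(
    x, scale, amax, out_dtype=jnp.float8_e4m3fn,
    static_axis_boundary=-1, transpose_axis_boundary=1,
)
# y keeps the (2, 3) layout while y_t is its (3, 2) transpose; both hold the same FP8 values.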
class TransposePrimitive(BasePrimitive):
"""
Transpose Primitive
......@@ -176,6 +200,8 @@ def transpose(
"""
transpose wrapper
"""
if not TransposePrimitive.enabled():
return _jax_transpose(x, static_axis_boundary, transpose_axis_boundary)
return TransposePrimitive.outer_primitive.bind(
x,
static_axis_boundary=static_axis_boundary,
......@@ -381,6 +407,15 @@ def cast_transpose(
cast transpose wrapper
Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale`
"""
if not CastTransposePrimitive.enabled():
return _jax_cast_transpose(
x,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
return CastTransposePrimitive.outer_primitive.bind(
x,
amax,
......@@ -631,6 +666,28 @@ def dbias_cast_transpose(
if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes
if not DBiasCastTransposePrimitive.enabled():
casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose(
dz,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
dbias = jnp.sum(
dz,
axis=tuple(
range(
transpose_axis_boundary
if transpose_axis_boundary > 0
else transpose_axis_boundary + dz.ndim
)
),
keepdims=False,
)
return casted_dz, cast_transposed_dz, dbias, updated_amax
return DBiasCastTransposePrimitive.outer_primitive.bind(
dz,
amax,
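To make the dbias reduction in the fallback above concrete: the sum runs over all axes in front of the transpose boundary. A tiny sketch with a hypothetical (batch, sequence, hidden) gradient:

import jax.numpy as jnp

dz = jnp.ones((4, 16, 32), jnp.float32)     # hypothetical (batch, seq, hidden) gradient
transpose_axis_boundary = -1
axes = tuple(range(transpose_axis_boundary + dz.ndim))   # (0, 1): the leading axes
dbias = jnp.sum(dz, axis=axes)              # shape (32,): one bias gradient per hidden unit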
......@@ -947,6 +1004,31 @@ def dact_lu_dbias_cast_transpose(
if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes
if not DActLuDBiasCastTransposePrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
(dx,) = vjp_func(dz)
casted_dx, cast_transposed_dx, updated_amax = _jax_cast_transpose(
dx,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
dbias = jnp.squeeze(
jnp.sum(
dx,
axis=tuple(
range(
transpose_axis_boundary
if transpose_axis_boundary > 0
else transpose_axis_boundary + dx.ndim
)
),
)
)
return casted_dx, cast_transposed_dx, dbias, updated_amax
act_type_id = ActivationEnum[activation_type]
return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
dz,
......@@ -1161,6 +1243,17 @@ def dgated_act_lu_cast_transpose(
Return FP8(dgated_act_lu(inputs))
"""
act_type_id = ActivationEnum[activation_type]
if not DgatedActLuCastTransposePrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
(dx,) = vjp_func(dz)
return _jax_cast_transpose(
dx,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=-2,
)
return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
dz,
x,
......
......@@ -69,14 +69,14 @@ def _layernorm_fwd_rule(
mu = None
else:
raise ValueError(f"{layernorm_type=} is not supported.")
return output, (x, mu, rsigma, gamma)
return output, (x, mu, rsigma, gamma, beta)
def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
x, mu, rsigma, gamma = ctx
x, mu, rsigma, gamma, beta = ctx
if layernorm_type == "layernorm":
dx, dgamma, dbeta = tex.layernorm_bwd(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
dz, x, mu, rsigma, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
elif layernorm_type == "rmsnorm":
assert (
......@@ -267,6 +267,7 @@ def _layernorm_fp8_dot_fwd_rule(
rsigma,
x,
gamma,
beta,
x_contracting_dims,
k_contracting_dims,
maybe_fp32_to_fm32,
......@@ -300,6 +301,7 @@ def _layernorm_fp8_dot_bwd_rule(
rsigma,
x,
gamma,
beta,
x_contracting_dims,
k_contracting_dims,
maybe_fp32_to_fm32,
......@@ -352,7 +354,14 @@ def _layernorm_fp8_dot_bwd_rule(
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
if layernorm_type == "layernorm":
dx, dgamma, dbeta = tex.layernorm_bwd(
dgrad, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
dgrad,
x,
mu,
rsigma,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
)
else:
assert (
......
......@@ -344,6 +344,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
mu,
rsigma,
gamma,
beta,
dot_1_output,
casted_activation_lu_out,
casted_kernel_1,
......@@ -390,6 +391,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
mu,
rsigma,
gamma,
beta,
dot_1_output,
casted_activation_lu_out,
casted_kernel_1,
......@@ -568,7 +570,14 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
if layernorm_type == "layernorm":
dx, dgamma, dbeta = tex.layernorm_bwd(
dgrad_1, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
dgrad_1,
x,
mu,
rsigma,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
)
else:
assert (
......
......@@ -49,18 +49,18 @@ def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type):
else:
output = tex.scaled_softmax_fwd(logits, scale_factor)
return output, (output,)
return output, (output, logits, mask)
def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz):
(softmax_output,) = ctx
(softmax_output, logits, mask) = ctx
if softmax_type is SoftmaxType.SCALED_MASKED:
dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, scale_factor)
dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, logits, mask, scale_factor)
elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, scale_factor)
dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, logits, scale_factor)
else:
dgrad = tex.scaled_softmax_bwd(dz, softmax_output, scale_factor)
dgrad = tex.scaled_softmax_bwd(dz, softmax_output, logits, scale_factor)
return (dgrad, None)
......