Unverified Commit b840898b authored by vthumbe1503, committed by GitHub

[JAX] Clamped Swiglu Integration (#2194)


Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
JAX integration for clamped SwiGLU. This is a continuation of the PR that added clamped SwiGLU (used in GPT OSS) support in TE along with the PyTorch integration. This PR hooks up the nvte APIs for clamped swiglu and dswiglu to TE/JAX.
parent e30c36a3
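
A minimal usage sketch of the new `act_params` hook from the Python side (the import paths, shapes, and variable names below are assumptions based on this diff, not code taken from the repository):

```python
import jax
import jax.numpy as jnp
from transformer_engine.jax.activation import activation  # assumed public import path
import transformer_engine.jax.cpp_extensions as tex

# Gated activations expect the gate/linear pair stacked on axis -2.
x = jax.random.normal(jax.random.PRNGKey(0), (16, 2, 128))

# Clamped SwiGLU (GPT OSS) takes explicit parameters; other activations can leave act_params=None.
params = tex.activation.ActivationParams.create(
    activation_type=("clamped_silu", "clamped_linear"), limit=7.0, alpha=1.702
)
out = activation(x, activation_type=("clamped_silu", "clamped_linear"), act_params=params)
```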
......@@ -170,6 +170,7 @@ ALL_ACTIVATION_TYPES = [
("quick_gelu", "linear"),
("squared_relu",),
("squared_relu", "linear"),
("clamped_silu", "clamped_linear"),
]
ACTIVATION_TYPES = {
......@@ -182,17 +183,21 @@ ACTIVATION_TYPES = {
class TestActivation:
def ref_act(self, x, activation_type):
return _jax_act_lu(x, activation_type).data
def ref_act(self, x, activation_type, act_params):
return _jax_act_lu(x, activation_type, act_params=act_params).data
def value_n_grad_ref_func(self, x, activation_type):
def value_n_grad_ref_func(self, x, activation_type, act_params):
jitted_reference = jit(
value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,))
value_and_grad(
lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)
)
)
return jitted_reference(x)
def primitive_func(self, inputs, activation_type, quantizer):
out = activation(inputs, activation_type=activation_type, quantizer=quantizer)
def primitive_func(self, inputs, activation_type, quantizer, act_params):
out = activation(
inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params
)
return jnp.mean(out)
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
......@@ -209,12 +214,20 @@ class TestActivation:
x = jnp.repeat(x, len(activation_type), axis=-2)
value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
......@@ -234,7 +247,8 @@ class TestActivation:
self.activation_type = activation_type
value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
value_and_grad(self.primitive_func, (0,)),
static_argnums=(1, 3),
)
quantizer = QuantizerFactory.create(
......@@ -242,9 +256,21 @@ class TestActivation:
q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE,
)
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(
x, activation_type, quantizer, act_params
)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
assert_allclose(prim_out, ref_out, dtype=output_type)
assert_allclose(prim_grad, ref_grad, dtype=output_type)
......@@ -273,10 +299,18 @@ class TestActivation:
q_dtype=output_type,
q_layout=q_layout,
)
te_output = tex.act_lu(x, activation_type, te_quantizer)
jax_output = _jax_act_lu(x, activation_type, jax_quantizer)
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
te_output = tex.act_lu(x, activation_type, te_quantizer, act_params)
jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params)
assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
......@@ -296,10 +330,18 @@ class TestActivation:
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
)
output = tex.act_lu(x, activation_type, quantizer)
ref_out = self.ref_act(x, activation_type)
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
output = tex.act_lu(x, activation_type, quantizer, act_params)
ref_out = self.ref_act(x, activation_type, act_params)
assert_dequantized_scaled_tensor(output, ref_out)
......@@ -734,6 +776,7 @@ class TestFusedQuantize:
def _test_quantize_dact_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
......@@ -785,7 +828,7 @@ class TestFusedQuantize:
(in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling())
# Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently.
or (
activation_type == ("squared_relu",)
activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")}
and in_dtype == jnp.bfloat16
and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
)
......
......@@ -39,6 +39,7 @@ enum class NVTE_Activation_Type {
QGEGLU,
SRELU,
SREGLU,
CLAMPED_SWIGLU
};
/*! \brief Computes the GeLU activation of the input.
......
......@@ -924,7 +924,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
checkCuDriverContext(stream);
......@@ -1006,7 +1006,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
checkCuDriverContext(stream);
......@@ -1138,7 +1138,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
case ScalingType::COLWISE:
......@@ -1155,7 +1154,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
......@@ -1180,7 +1178,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) {
void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(input.flat_last_dim() % 2 == 0,
......@@ -1213,7 +1211,7 @@ void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t st
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP &p,
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p,
cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input");
......@@ -1252,7 +1250,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamO
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
constexpr bool allow_empty = false;
CheckInputTensor(gated_input, "gated_input");
......@@ -1318,7 +1316,7 @@ namespace detail {
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output,
ParamOP &p, cudaStream_t stream) {
ParamOP p, cudaStream_t stream) {
using namespace gated_kernels;
Tensor grad_empty_tensor;
const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
......
......@@ -11,7 +11,6 @@ from functools import partial
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize.tensor import NoScaleTensor
......@@ -22,6 +21,7 @@ def activation(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
act_params: Optional[tex.activation.ActivationParams] = None,
) -> jnp.ndarray:
"""Apply activation functions to input tensor with optional quantization.
......@@ -32,17 +32,19 @@ def activation(
x: Input tensor to apply activations to
activation_type: Sequence of activation functions
quantizer: Optional quantizer for quantizing the output
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
Returns:
Activated output tensor
"""
assert x.shape[-1] % len(activation_type) == 0
output = _activation(x, activation_type, quantizer)
output = _activation(x, activation_type, quantizer, act_params)
return output
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation(x, activation_type, quantizer):
@partial(jax.custom_vjp, nondiff_argnums=(1, 3))
def _activation(x, activation_type, quantizer, act_params):
"""Internal implementation of activation with custom VJP.
This function implements the core activation logic with support for
......@@ -52,36 +54,42 @@ def _activation(x, activation_type, quantizer):
x: Input tensor
activation_type: Sequence of activation functions
quantizer: Optional quantizer
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
Returns:
Activated tensor
"""
_output, _ = _activation_fwd_rule(x, activation_type, quantizer)
_output, _ = _activation_fwd_rule(x, activation_type, quantizer, act_params)
return _output
def _activation_fwd_rule(x, activation_type, quantizer):
def _activation_fwd_rule(x, activation_type, quantizer, act_params):
"""Forward pass rule for activation function.
Args:
x: Input tensor
activation_type: Sequence of activation functions
quantizer: Optional quantizer
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
Returns:
Tuple of (output, context) for backward pass
"""
fwd_output = tex.act_lu(x, activation_type, quantizer)
fwd_output = tex.act_lu(x, activation_type, quantizer, act_params)
# This is a no-op for higher-precision tensors
fwd_output = fwd_output.dequantize()
return fwd_output, (x, quantizer)
def _activation_bwd_rule(activation_type, ctx, g):
def _activation_bwd_rule(activation_type, act_params, ctx, g):
"""Backward pass rule for activation function.
Args:
activation_type: Sequence of activation functions
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
ctx: Context from forward pass
g: Gradient from upstream
......@@ -90,7 +98,7 @@ def _activation_bwd_rule(activation_type, ctx, g):
"""
(x, _) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
dx = tex.dact_lu(g, x, activation_type, act_params=act_params)
# No quantization is used in this VJP backward, so the output should
# always be a NoScaleTensor
assert isinstance(dx, NoScaleTensor)
......
......@@ -5,6 +5,7 @@
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial
from dataclasses import dataclass
import jax
import jax.numpy as jnp
......@@ -12,9 +13,9 @@ from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec
import numpy as np
import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type
from .base import BasePrimitive, register_primitive
from .misc import (
jax_dtype_to_te_dtype,
......@@ -51,17 +52,87 @@ ActivationEnum = {
("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
("squared_relu",): NVTE_Activation_Type.SRELU,
("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU,
}
def _convert_to_activation_function(fn_or_string):
@dataclass(frozen=True)
class ClampedSwigluParams:
"""Parameters for the Clamped SwiGLU activation function
used in GPT OSS."""
limit: float = 7.0
alpha: float = 1.702
def __hash__(self):
"""Custom hash function to ensure dataclass is hashable for jax jit to work.
Returns:
int: Hash value of the dataclass instance.
"""
return hash((self.limit, self.alpha))
def to_ffi_lowering_dict(self):
"""Convert the activation parameters to a dictionary format for FFI lowering.
Returns:
dict: A dictionary representation of the activation parameters consumable by
XLA FFI bindings for activation functions.
"""
return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)}
@dataclass(frozen=True)
class ActivationParams:
"""Parameters for various activation functions.
Currently only the clamped SwiGLU activation has parameters.
"""
clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams()
@staticmethod
def create(activation_type, **kwargs):
"""Factory method to create ActivationParams based on activation_type."""
CLAMPED_ACTIVATION_TYPES = {
("clamped_silu", "clamped_linear"),
"clamped_silu",
"clamped_linear",
}
if activation_type in CLAMPED_ACTIVATION_TYPES:
return ActivationParams(ClampedSwigluParams(**kwargs))
return ActivationParams() # Default params for activations without parameters
def __hash__(self):
"""Custom hash function to ensure dataclass is hashable for jax jit to work"""
return hash((self.clamped_swiglu,))
def to_ffi_lowering_dict(self):
"""Convert the activation parameters to a dictionary format for FFI lowering.
Returns:
dict: A dictionary representation of the activation parameters consumable by
XLA FFI bindings for activation functions.
"""
return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()}
def _convert_to_activation_function(fn_or_string, act_params: ActivationParams):
"""Convert a string to an activation function."""
if fn_or_string == "linear":
return lambda x: x
if fn_or_string == "clamped_linear":
# Linear branch of the clamped SwiGLU activation used in GPT OSS:
# the input is clamped to [-limit, limit] and then shifted by +1.
limit = act_params.clamped_swiglu.limit
return lambda x: jnp.clip(x, min=-limit, max=limit) + 1
if fn_or_string == "quick_gelu":
return lambda x: jax.nn.sigmoid(1.702 * x) * x
if fn_or_string == "squared_relu":
return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
if fn_or_string == "clamped_silu":
limit = act_params.clamped_swiglu.limit
alpha = act_params.clamped_swiglu.alpha
return lambda x: jax.nn.sigmoid(alpha * jnp.minimum(x, limit)) * jnp.minimum(x, limit)
if isinstance(fn_or_string, str):
return getattr(jax.nn, fn_or_string)
if callable(fn_or_string):
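
For reference, the two new branches above compose into the following pure-JAX sketch of clamped SwiGLU (the helper name is hypothetical; the defaults mirror `ClampedSwigluParams`):

```python
import jax
import jax.numpy as jnp

def clamped_swiglu_reference(x_gate, x_linear, limit=7.0, alpha=1.702):
    """Math implemented by the clamped_silu / clamped_linear lambdas above."""
    g = jnp.minimum(x_gate, limit)                          # clamped_silu: clamp the gate from above only
    silu_part = jax.nn.sigmoid(alpha * g) * g
    linear_part = jnp.clip(x_linear, -limit, limit) + 1.0   # clamped_linear: two-sided clamp, then +1 shift
    return silu_part * linear_part                          # product of the two branches, as in _jax_act_lu
```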
......@@ -84,7 +155,8 @@ class ActLuPrimitive(BasePrimitive):
6,
7,
8,
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer
9,
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer, act_params
inner_primitive = None
outer_primitive = None
......@@ -100,11 +172,12 @@ class ActLuPrimitive(BasePrimitive):
is_2x,
scale_dtype,
is_outer,
act_params,
):
"""
te_act_lu_p abstract
"""
del act_enum
del act_enum, act_params
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
......@@ -150,6 +223,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x,
scale_dtype,
is_outer,
act_params,
):
"""
te_gated_act_lu_p lowering rules
......@@ -158,9 +232,14 @@ class ActLuPrimitive(BasePrimitive):
x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
out = ffi.ffi_lowering(ActLuPrimitive.name)(
ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x
ctx,
x,
scale,
act_enum=act_enum,
scaling_mode=scaling_mode.value,
is_2x=is_2x,
act_params=act_params.to_ffi_lowering_dict(),
)
return out
......@@ -175,6 +254,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x,
scale_dtype,
is_outer,
act_params,
):
"""
to describe implementation
......@@ -193,6 +273,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x=is_2x,
scale_dtype=scale_dtype,
is_outer=False,
act_params=act_params,
)
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
......@@ -221,6 +302,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x,
scale_dtype,
is_outer,
act_params,
):
"""
to describe batch rules for vmap
......@@ -242,6 +324,7 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
act_params=act_params,
),
out_bdims,
)
......@@ -255,6 +338,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x,
scale_dtype,
is_outer,
act_params,
mesh,
arg_infos,
result_infos,
......@@ -266,6 +350,7 @@ class ActLuPrimitive(BasePrimitive):
scale_dtype,
act_len,
is_outer,
act_params,
) # Unused.
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
......@@ -318,6 +403,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x,
scale_dtype,
is_outer,
act_params,
mesh,
arg_infos,
result_infos,
......@@ -378,6 +464,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x=is_2x,
scale_dtype=scale_dtype,
is_outer=True,
act_params=act_params,
)
)
......@@ -405,11 +492,12 @@ class ActLuPrimitive(BasePrimitive):
is_2x,
scale_dtype,
is_outer,
act_params,
mesh,
value_types,
result_types,
):
del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types, act_params
prefix = "ActLu_"
input_shape = value_types[0].shape
output_shape = input_shape[:-2] + input_shape[-1:]
......@@ -455,8 +543,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
name = "te_dact_dbias_quantize_ffi"
multiple_results = True
# out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer
impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10)
# out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer, act_params
impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11)
inner_primitive = None
outer_primitive = None
......@@ -474,11 +562,12 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum,
act_len,
is_outer,
act_params,
):
"""
te_dact_dbias_quantize_p abstract
"""
del act_enum
del act_enum, act_params
dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_dtype
......@@ -575,6 +664,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum,
act_len,
is_outer,
act_params,
):
"""
te_dact_dbias_quantize_p lowering rules
......@@ -593,6 +683,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_2x=is_2x,
is_dbias=is_dbias,
act_enum=int(act_enum),
act_params=act_params.to_ffi_lowering_dict(),
)
@staticmethod
......@@ -608,6 +699,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum,
act_len,
is_outer,
act_params,
):
"""
te_dact_dbias_quantize_p impl
......@@ -627,6 +719,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum=act_enum,
act_len=act_len,
is_outer=False,
act_params=act_params,
)
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
......@@ -655,6 +748,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum,
act_len,
is_outer,
act_params,
):
"""
to describe batch rules for vmap
......@@ -685,6 +779,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
act_params=act_params,
),
out_bdims,
)
......@@ -699,11 +794,12 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum,
act_len,
is_outer,
act_params,
mesh,
arg_infos,
result_infos,
):
del out_dtype, result_infos, act_enum
del out_dtype, result_infos, act_enum, act_params
del scale_dtype, act_len, is_outer
x_spec = get_padded_spec(arg_infos[1])
scale_spec = get_padded_spec(arg_infos[2])
......@@ -774,6 +870,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum,
act_len,
is_outer,
act_params,
mesh,
arg_infos,
result_infos,
......@@ -854,6 +951,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum=act_enum,
act_len=act_len,
is_outer=True,
act_params=act_params,
)
)
if is_dbias:
......@@ -880,11 +978,13 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum,
act_len,
is_outer,
act_params,
mesh,
value_types,
result_types,
):
del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types, act_params
prefix = "DActLuDBias_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2
......@@ -923,20 +1023,22 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
"""Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]:
def _jax_act_lu(
inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None
) -> Union[NoScaleTensor, ScaledTensor]:
"""
JAX native activation implementation
"""
act_params = act_params if act_params is not None else ActivationParams()
act_len = len(activation_type)
assert inputs.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {inputs.shape} and act_len {act_len}"
)
x = jnp.split(inputs, act_len, axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
x_i = _convert_to_activation_function(act_fn, act_params)(x[idx])
acts.append(x_i)
x = reduce(operator.mul, acts)
x = jnp.squeeze(x, axis=-2)
......@@ -951,10 +1053,12 @@ def _jax_quantize_dact_dbias(
activation_type: Sequence[Union[str, Callable]],
is_dbias: bool = True,
quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None,
):
"""
JAX implementation of dact_lu and dbias with optional quantization
"""
act_params = act_params if act_params is not None else ActivationParams()
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
......@@ -962,7 +1066,8 @@ def _jax_quantize_dact_dbias(
)
_, vjp_func = jax.vjp(
partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
partial(_jax_act_lu, activation_type=activation_type, act_params=act_params),
x.astype(jnp.float32),
)
# VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards.
dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None)
......@@ -985,6 +1090,7 @@ def act_lu(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None,
amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> Union[jnp.ndarray, ScaledTensor]:
"""Activation with optional quantization.
......@@ -1008,24 +1114,22 @@ def act_lu(
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}"
)
act_params = act_params if act_params is not None else ActivationParams()
if not ActLuPrimitive.enabled():
return _jax_act_lu(x, activation_type, quantizer)
return _jax_act_lu(x, activation_type, quantizer, act_params)
# TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_act_lu(x, activation_type, quantizer)
return _jax_act_lu(x, activation_type, quantizer, act_params)
# TE/common does not support 2x quantization for DelayedScaling yet
war_output = try_apply_delayed_scaling_2x_war(
f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer
f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer, act_params=act_params
)
if war_output is not None:
return war_output
scale = jnp.empty((1,), jnp.float32)
output_shape = (*x.shape[:-2], x.shape[-1])
if quantizer is None:
out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
x,
......@@ -1037,6 +1141,7 @@ def act_lu(
is_2x=False,
scale_dtype=jnp.float32,
is_outer=True,
act_params=act_params,
)
out = out.reshape(output_shape)
out = NoScaleTensor(
......@@ -1051,6 +1156,7 @@ def act_lu(
x=x,
activation_type=activation_type,
quantizer=None,
act_params=act_params,
)
out, _ = _quantize_dbias_impl(
out,
......@@ -1060,7 +1166,6 @@ def act_lu(
amax_scope=amax_scope,
)
return out
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
......@@ -1080,6 +1185,7 @@ def act_lu(
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
is_outer=True,
act_params=act_params,
)
quantizer.update(updated_amax)
......@@ -1102,6 +1208,7 @@ def quantize_dact_dbias(
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
is_dbias: bool = True,
quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None,
) -> Tuple[ScaledTensor, jnp.ndarray]:
"""Compute gradients of activation and bias with optional quantization.
......@@ -1118,7 +1225,7 @@ def quantize_dact_dbias(
- The gradient of the activation with respect to the input.
- The gradient of the activation with respect to the bias.
"""
act_params = act_params if act_params is not None else ActivationParams()
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
......@@ -1131,8 +1238,7 @@ def quantize_dact_dbias(
if not PrimitiveClass.enabled() or (
quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE
):
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params)
if quantizer is None:
output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
dz,
......@@ -1148,6 +1254,7 @@ def quantize_dact_dbias(
act_enum=act_type_id,
act_len=act_len,
is_outer=True,
act_params=act_params,
)
output = output.astype(x.dtype)
dbias = None
......@@ -1163,7 +1270,11 @@ def quantize_dact_dbias(
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out = dact_lu(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None
dz.astype(jnp.float32),
x.astype(jnp.float32),
activation_type,
quantizer=None,
act_params=act_params,
)
return _quantize_dbias_impl(
out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
......@@ -1180,6 +1291,7 @@ def quantize_dact_dbias(
is_dbias=is_dbias,
quantizer=quantizer,
flatten_axis=-2,
act_params=act_params,
)
if war_output is not None:
return war_output
......@@ -1191,6 +1303,7 @@ def quantize_dact_dbias(
x=x,
activation_type=activation_type,
quantizer=None,
act_params=act_params,
)
out, dbias = _quantize_dbias_impl(
out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
......@@ -1203,7 +1316,10 @@ def quantize_dact_dbias(
# TE/common dact_dbias_quantize does not support gated act yet
if is_dbias and is_gated:
dgated = dact_lu(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
dz.astype(jnp.float32),
x.astype(jnp.float32),
activation_type=activation_type,
act_params=act_params,
)
out, dbias = _quantize_dbias_impl(
dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
......@@ -1229,6 +1345,7 @@ def quantize_dact_dbias(
act_enum=act_type_id,
act_len=act_len,
is_outer=True,
act_params=act_params,
)
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
......@@ -1257,6 +1374,7 @@ def dact_lu(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
"""
Backward pass for activation with optional quantization.
......@@ -1270,11 +1388,13 @@ def dact_lu(
Returns:
The gradient of the activation with respect to the input.
"""
act_params = act_params if act_params is not None else ActivationParams()
output, _ = quantize_dact_dbias(
dz=dz,
x=x,
activation_type=activation_type,
is_dbias=False,
quantizer=quantizer,
act_params=act_params,
)
return output
......@@ -36,6 +36,15 @@
namespace transformer_engine {
namespace jax {
struct ClampedSwigluConfig {
float limit;
float alpha;
};
struct ActivationConfig {
ClampedSwigluConfig clamped_swiglu;
};
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }
// Activation
......@@ -137,6 +146,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);
} // namespace jax
} // namespace transformer_engine
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig,
::xla::ffi::StructMember<float>("limit"),
::xla::ffi::StructMember<float>("alpha"));
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
transformer_engine::jax::ActivationConfig,
::xla::ffi::StructMember<transformer_engine::jax::ClampedSwigluConfig>("clamped_swiglu"));
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op);
......
......@@ -18,7 +18,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int) {
bool is_2x_int, ActivationConfig act_params) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
......@@ -125,6 +128,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
case NVTE_Activation_Type::SREGLU:
nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::CLAMPED_SWIGLU:
nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha,
stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
......@@ -145,17 +152,19 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"),
.Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"),
FFI_CudaGraph_Traits);
Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum,
JAXX_Scaling_Mode scaling_mode, bool is_2x_int) {
JAXX_Scaling_Mode scaling_mode, bool is_2x_int,
ActivationConfig act_params) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf,
act_enum, scaling_mode, is_2x_int);
act_enum, scaling_mode, is_2x_int, act_params);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
......@@ -170,7 +179,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
.Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"));
.Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
......@@ -240,7 +250,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias) {
int64_t act_enum, bool is_2x, bool is_dbias,
ActivationConfig act_params) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
......@@ -407,6 +421,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
case NVTE_Activation_Type::SREGLU:
nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::CLAMPED_SWIGLU:
nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
swiglu_limit, swiglu_alpha, stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
......@@ -432,21 +450,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<bool>("is_dbias"),
.Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"),
FFI_CudaGraph_Traits);
Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t act_enum,
bool is_2x, bool is_dbias) {
Error_Type DActLuDBiasQuantizeInitializeFFI(
cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x,
bool is_dbias, ActivationConfig act_params) {
return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
act_input_buf, scale_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias);
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
......@@ -466,7 +483,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<bool>("is_dbias"));
.Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"));
} // namespace jax
} // namespace transformer_engine
......@@ -143,6 +143,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("QGEGLU", NVTE_Activation_Type::QGEGLU)
.value("SRELU", NVTE_Activation_Type::SRELU)
.value("SREGLU", NVTE_Activation_Type::SREGLU)
.value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU)
.export_values();
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
......
......@@ -898,6 +898,10 @@ class LayerNormMLP(TransformerEngineBase):
activations: Sequence[Union[str, Callable]], default = ('relu',)
The sequence of activation functions to apply after the first dense layer transformation.
Each activation has its own transformation layer.
activation_params: dict, default = None
The parameters needed (if any) by the activation functions specified in :attr:`activations`.
At the moment only ('clamped_silu', 'clamped_linear'), i.e. the clamped SwiGLU used in GPT OSS,
needs additional parameters.
intermediate_dropout_rng_name: str, default = 'dropout'
The key in the RNGs given via flax.linen.Module.apply that is used to generate Dropout masks.
intermediate_dropout_rate: float, default = 0.1
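
A sketch of setting the new field from user code (the import path and surrounding arguments are assumptions, not taken from this diff):

```python
from transformer_engine.jax.flax import LayerNormMLP  # assumed public import path

mlp = LayerNormMLP(
    intermediate_dim=4096,
    activations=("clamped_silu", "clamped_linear"),
    activation_params={"limit": 7.0, "alpha": 1.702},  # forwarded to ActivationParams.create inside layernorm_mlp
)
```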
......@@ -956,6 +960,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_axes_2: Tuple[str, ...] = ("embed",)
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ("relu",)
activation_params: dict = None
intermediate_dropout_rng_name: str = "dropout"
intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = ()
......@@ -1023,6 +1028,7 @@ class LayerNormMLP(TransformerEngineBase):
("relu", "linear"),
("quick_gelu", "linear"),
("squared_relu", "linear"),
("clamped_silu", "clamped_linear"),
]
act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
normalized_acts = []
......@@ -1031,7 +1037,9 @@ class LayerNormMLP(TransformerEngineBase):
return False
normalized_acts.append(act.lower())
normalized_acts = tuple(
reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
reversed(normalized_acts)
if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear")
else normalized_acts
)
is_act_implemented = normalized_acts in (gated_act_pool + act_pool)
......@@ -1150,6 +1158,7 @@ class LayerNormMLP(TransformerEngineBase):
ffn1_ckpt_name=self.ffn1_ckpt_name,
ffn2_ckpt_name=self.ffn2_ckpt_name,
activation_type=normalized_acts,
activation_params=self.activation_params,
quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
)
out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)
......@@ -1287,4 +1296,4 @@ class LayerNormMLP(TransformerEngineBase):
out = checkpoint_name(out, self.ffn2_ckpt_name)
assert out.dtype == input_dtype
return out, ln_output # Output, layner_norm_output
return out, ln_output # Output, layer_norm_output
......@@ -1632,6 +1632,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mlp_activations: Sequence[str], default = ('relu', )
The sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer.
mlp_activation_params: dict, default = None
Parameters for the MLP activations. This is only used when ('clamped_silu', 'clamped_linear') is in
:attr:`mlp_activations`; at the moment clamped SwiGLU is the only activation that requires parameters.
use_bias: bool, default = False
Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
If set to False, the layer will not learn additive biases.
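
The corresponding knob one level up, on TransformerLayer (again a sketch; only the two new arguments matter here):

```python
from transformer_engine.jax.flax import TransformerLayer  # assumed public import path

layer = TransformerLayer(
    mlp_activations=("clamped_silu", "clamped_linear"),
    mlp_activation_params={"limit": 7.0, "alpha": 1.702},  # passed through to LayerNormMLP.activation_params
)
```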
......@@ -1752,6 +1755,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mha_kernel_init: Initializer = None
mlp_kernel_init: Initializer = None
mlp_activations: Sequence[str] = ("relu",)
mlp_activation_params: dict = None
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
apply_residual_connection_post_layernorm: bool = False
......@@ -2046,6 +2050,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
return_layernorm_output=self.apply_residual_connection_post_layernorm,
intermediate_dim=self.mlp_hidden_size,
activations=self.mlp_activations,
activation_params=self.mlp_activation_params,
intermediate_dropout_rng_name=self.dropout_rng_name,
intermediate_dropout_rate=self.intermediate_dropout,
intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
......
......@@ -50,6 +50,7 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
activation_params: dict = None,
collective_op_sets: Tuple[tex.CollectiveOpSet] = (
tex.noop_collective_op_set,
tex.noop_collective_op_set,
......@@ -138,13 +139,14 @@ def layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
activation_params,
collective_op_sets,
quantizer_sets,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20))
def _layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
......@@ -165,6 +167,7 @@ def _layernorm_mlp(
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
activation_params: dict,
collective_op_sets: Tuple[tex.CollectiveOpSet],
quantizer_sets,
):
......@@ -220,6 +223,7 @@ def _layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
activation_params,
collective_op_sets,
quantizer_sets,
)
......@@ -246,6 +250,7 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
activation_params,
collective_op_sets,
quantizer_sets,
):
......@@ -335,6 +340,11 @@ def _layernorm_mlp_fwd_rule(
dot_1_output,
activation_type,
quantizer=ffn2_quantizer_set.x,
act_params=(
tex.activation.ActivationParams.create(activation_type, **activation_params)
if activation_params
else None
),
)
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
......@@ -402,6 +412,7 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
activation_params,
collective_op_sets,
ctx,
grad,
......@@ -497,6 +508,11 @@ def _layernorm_mlp_bwd_rule(
activation_type=activation_type,
is_dbias=use_bias_1,
quantizer=ffn2_quantizer_set.dgrad,
act_params=(
tex.activation.ActivationParams.create(activation_type, **activation_params)
if activation_params
else None
),
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......