Unverified Commit b840898b authored by vthumbe1503's avatar vthumbe1503 Committed by GitHub
Browse files

[JAX] Clamped Swiglu Integration (#2194)


Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>
* JAX integration for clamped SwiGLU. This is a continuation of the PR that added Clamped SwiGLU (used in GPT OSS) support in TE along with the PyTorch integration. This PR hooks up the clamped SwiGLU and dSwiGLU NVTE APIs to TE JAX.
parent e30c36a3
...@@ -170,6 +170,7 @@ ALL_ACTIVATION_TYPES = [ ...@@ -170,6 +170,7 @@ ALL_ACTIVATION_TYPES = [
("quick_gelu", "linear"), ("quick_gelu", "linear"),
("squared_relu",), ("squared_relu",),
("squared_relu", "linear"), ("squared_relu", "linear"),
("clamped_silu", "clamped_linear"),
] ]
ACTIVATION_TYPES = { ACTIVATION_TYPES = {
...@@ -182,17 +183,21 @@ ACTIVATION_TYPES = { ...@@ -182,17 +183,21 @@ ACTIVATION_TYPES = {
class TestActivation: class TestActivation:
def ref_act(self, x, activation_type): def ref_act(self, x, activation_type, act_params):
return _jax_act_lu(x, activation_type).data return _jax_act_lu(x, activation_type, act_params=act_params).data
def value_n_grad_ref_func(self, x, activation_type): def value_n_grad_ref_func(self, x, activation_type, act_params):
jitted_reference = jit( jitted_reference = jit(
value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,)) value_and_grad(
lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)
)
) )
return jitted_reference(x) return jitted_reference(x)
def primitive_func(self, inputs, activation_type, quantizer): def primitive_func(self, inputs, activation_type, quantizer, act_params):
out = activation(inputs, activation_type=activation_type, quantizer=quantizer) out = activation(
inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params
)
return jnp.mean(out) return jnp.mean(out)
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
...@@ -209,12 +214,20 @@ class TestActivation: ...@@ -209,12 +214,20 @@ class TestActivation:
x = jnp.repeat(x, len(activation_type), axis=-2) x = jnp.repeat(x, len(activation_type), axis=-2)
value_n_grad_primitive_func = jit( value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
) )
act_args = (
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None) {"limit": 0.75, "alpha": 1.702}
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
...@@ -234,7 +247,8 @@ class TestActivation: ...@@ -234,7 +247,8 @@ class TestActivation:
self.activation_type = activation_type self.activation_type = activation_type
value_n_grad_primitive_func = jit( value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) value_and_grad(self.primitive_func, (0,)),
static_argnums=(1, 3),
) )
quantizer = QuantizerFactory.create( quantizer = QuantizerFactory.create(
...@@ -242,9 +256,21 @@ class TestActivation: ...@@ -242,9 +256,21 @@ class TestActivation:
q_dtype=output_type, q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE, q_layout=QuantizeLayout.ROWWISE,
) )
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer) act_params = (
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(
x, activation_type, quantizer, act_params
)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_out, ref_out, dtype=output_type)
assert_allclose(prim_grad, ref_grad, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type)
...@@ -273,10 +299,18 @@ class TestActivation: ...@@ -273,10 +299,18 @@ class TestActivation:
q_dtype=output_type, q_dtype=output_type,
q_layout=q_layout, q_layout=q_layout,
) )
act_args = (
te_output = tex.act_lu(x, activation_type, te_quantizer) {"limit": 0.75, "alpha": 1.702}
jax_output = _jax_act_lu(x, activation_type, jax_quantizer) if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
te_output = tex.act_lu(x, activation_type, te_quantizer, act_params)
jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params)
assert_bitwise_scaled_tensors(te_output, jax_output) assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
...@@ -296,10 +330,18 @@ class TestActivation: ...@@ -296,10 +330,18 @@ class TestActivation:
quantizer = QuantizerFactory.create( quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
) )
act_args = (
output = tex.act_lu(x, activation_type, quantizer) {"limit": 0.75, "alpha": 1.702}
ref_out = self.ref_act(x, activation_type) if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
output = tex.act_lu(x, activation_type, quantizer, act_params)
ref_out = self.ref_act(x, activation_type, act_params)
assert_dequantized_scaled_tensor(output, ref_out) assert_dequantized_scaled_tensor(output, ref_out)
...@@ -734,6 +776,7 @@ class TestFusedQuantize: ...@@ -734,6 +776,7 @@ class TestFusedQuantize:
def _test_quantize_dact_dbias( def _test_quantize_dact_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
): ):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2) subkeys = jax.random.split(key, 2)
x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1) x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
...@@ -785,7 +828,7 @@ class TestFusedQuantize: ...@@ -785,7 +828,7 @@ class TestFusedQuantize:
(in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()) (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling())
# Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently. # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently.
or ( or (
activation_type == ("squared_relu",) activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")}
and in_dtype == jnp.bfloat16 and in_dtype == jnp.bfloat16
and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
) )
......
...@@ -39,6 +39,7 @@ enum class NVTE_Activation_Type { ...@@ -39,6 +39,7 @@ enum class NVTE_Activation_Type {
QGEGLU, QGEGLU,
SRELU, SRELU,
SREGLU, SREGLU,
CLAMPED_SWIGLU
}; };
/*! \brief Computes the GeLU activation of the input. /*! \brief Computes the GeLU activation of the input.
......
...@@ -924,7 +924,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -924,7 +924,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) { cudaStream_t stream) {
checkCuDriverContext(stream); checkCuDriverContext(stream);
...@@ -1006,7 +1006,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu ...@@ -1006,7 +1006,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) { cudaStream_t stream) {
checkCuDriverContext(stream); checkCuDriverContext(stream);
...@@ -1138,7 +1138,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -1138,7 +1138,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_colwise, tensor_map_output_gate_colwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p); scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
case ScalingType::COLWISE: case ScalingType::COLWISE:
...@@ -1155,7 +1154,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -1155,7 +1154,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p); scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
...@@ -1180,7 +1178,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -1180,7 +1178,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
} }
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)> template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) { void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input"); CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output"); CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(input.flat_last_dim() % 2 == 0, NVTE_CHECK(input.flat_last_dim() % 2 == 0,
...@@ -1213,7 +1211,7 @@ void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t st ...@@ -1213,7 +1211,7 @@ void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t st
template <typename ParamOP, float (*ActOP)(float, const ParamOP &), template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP &p, void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p,
cudaStream_t stream) { cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad"); CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input"); CheckInputTensor(input, "dgated_act_input");
...@@ -1252,7 +1250,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamO ...@@ -1252,7 +1250,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamO
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) { cudaStream_t stream) {
constexpr bool allow_empty = false; constexpr bool allow_empty = false;
CheckInputTensor(gated_input, "gated_input"); CheckInputTensor(gated_input, "gated_input");
...@@ -1318,7 +1316,7 @@ namespace detail { ...@@ -1318,7 +1316,7 @@ namespace detail {
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output,
ParamOP &p, cudaStream_t stream) { ParamOP p, cudaStream_t stream) {
using namespace gated_kernels; using namespace gated_kernels;
Tensor grad_empty_tensor; Tensor grad_empty_tensor;
const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
......
...@@ -11,7 +11,6 @@ from functools import partial ...@@ -11,7 +11,6 @@ from functools import partial
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .quantize.tensor import NoScaleTensor from .quantize.tensor import NoScaleTensor
...@@ -22,6 +21,7 @@ def activation( ...@@ -22,6 +21,7 @@ def activation(
x: jnp.ndarray, x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
act_params: Optional[tex.activation.ActivationParams] = None,
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Apply activation functions to input tensor with optional quantization. """Apply activation functions to input tensor with optional quantization.
...@@ -32,17 +32,19 @@ def activation( ...@@ -32,17 +32,19 @@ def activation(
x: Input tensor to apply activations to x: Input tensor to apply activations to
activation_type: Sequence of activation functions activation_type: Sequence of activation functions
quantizer: Optional quantizer for quantizing the output quantizer: Optional quantizer for quantizing the output
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
Returns: Returns:
Activated output tensor Activated output tensor
""" """
assert x.shape[-1] % len(activation_type) == 0 assert x.shape[-1] % len(activation_type) == 0
output = _activation(x, activation_type, quantizer) output = _activation(x, activation_type, quantizer, act_params)
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(1,)) @partial(jax.custom_vjp, nondiff_argnums=(1, 3))
def _activation(x, activation_type, quantizer): def _activation(x, activation_type, quantizer, act_params):
"""Internal implementation of activation with custom VJP. """Internal implementation of activation with custom VJP.
This function implements the core activation logic with support for This function implements the core activation logic with support for
...@@ -52,36 +54,42 @@ def _activation(x, activation_type, quantizer): ...@@ -52,36 +54,42 @@ def _activation(x, activation_type, quantizer):
x: Input tensor x: Input tensor
activation_type: Sequence of activation functions activation_type: Sequence of activation functions
quantizer: Optional quantizer quantizer: Optional quantizer
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
Returns: Returns:
Activated tensor Activated tensor
""" """
_output, _ = _activation_fwd_rule(x, activation_type, quantizer) _output, _ = _activation_fwd_rule(x, activation_type, quantizer, act_params)
return _output return _output
def _activation_fwd_rule(x, activation_type, quantizer): def _activation_fwd_rule(x, activation_type, quantizer, act_params):
"""Forward pass rule for activation function. """Forward pass rule for activation function.
Args: Args:
x: Input tensor x: Input tensor
activation_type: Sequence of activation functions activation_type: Sequence of activation functions
quantizer: Optional quantizer quantizer: Optional quantizer
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
Returns: Returns:
Tuple of (output, context) for backward pass Tuple of (output, context) for backward pass
""" """
fwd_output = tex.act_lu(x, activation_type, quantizer) fwd_output = tex.act_lu(x, activation_type, quantizer, act_params)
# This is a no-op for higher-precision tensors # This is a no-op for higher-precision tensors
fwd_output = fwd_output.dequantize() fwd_output = fwd_output.dequantize()
return fwd_output, (x, quantizer) return fwd_output, (x, quantizer)
def _activation_bwd_rule(activation_type, ctx, g): def _activation_bwd_rule(activation_type, act_params, ctx, g):
"""Backward pass rule for activation function. """Backward pass rule for activation function.
Args: Args:
activation_type: Sequence of activation functions activation_type: Sequence of activation functions
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
ctx: Context from forward pass ctx: Context from forward pass
g: Gradient from upstream g: Gradient from upstream
...@@ -90,7 +98,7 @@ def _activation_bwd_rule(activation_type, ctx, g): ...@@ -90,7 +98,7 @@ def _activation_bwd_rule(activation_type, ctx, g):
""" """
(x, _) = ctx (x, _) = ctx
assert x.dtype == g.dtype assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type) dx = tex.dact_lu(g, x, activation_type, act_params=act_params)
# No quantization is used in this VJP backward, so the output should # No quantization is used in this VJP backward, so the output should
# always be a NoScaleTensor # always be a NoScaleTensor
assert isinstance(dx, NoScaleTensor) assert isinstance(dx, NoScaleTensor)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from typing import Sequence, Union, Callable, Optional, Tuple from typing import Sequence, Union, Callable, Optional, Tuple
import operator import operator
from functools import reduce, partial from functools import reduce, partial
from dataclasses import dataclass
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -12,9 +13,9 @@ from jax import dtypes, ffi ...@@ -12,9 +13,9 @@ from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
import numpy as np
import transformer_engine_jax import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type from transformer_engine_jax import NVTE_Activation_Type
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .misc import ( from .misc import (
jax_dtype_to_te_dtype, jax_dtype_to_te_dtype,
...@@ -51,17 +52,87 @@ ActivationEnum = { ...@@ -51,17 +52,87 @@ ActivationEnum = {
("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU, ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
("squared_relu",): NVTE_Activation_Type.SRELU, ("squared_relu",): NVTE_Activation_Type.SRELU,
("squared_relu", "linear"): NVTE_Activation_Type.SREGLU, ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU,
} }
def _convert_to_activation_function(fn_or_string): @dataclass(frozen=True)
class ClampedSwigluParams:
    """Hyper-parameters of the Clamped SwiGLU activation used in GPT OSS."""

    # Clamp threshold applied to the pre-activation values.
    limit: float = 7.0
    # Sigmoid steepness used by the SiLU component.
    alpha: float = 1.702

    def __hash__(self):
        """Hash on the (limit, alpha) pair so instances can serve as static jit args.

        Returns:
            int: Hash value of the dataclass instance.
        """
        return hash((self.limit, self.alpha))

    def to_ffi_lowering_dict(self):
        """Convert the parameters into a kwargs dict for the XLA FFI bindings.

        Returns:
            dict: ``limit`` and ``alpha`` as ``np.float32`` scalars.
        """
        fields = (("limit", self.limit), ("alpha", self.alpha))
        return {name: np.float32(value) for name, value in fields}
@dataclass(frozen=True)
class ActivationParams:
    """Container for the tunable parameters of activation functions.

    Only Clamped SwiGLU currently carries any parameters; every other
    activation uses the default (empty) configuration.
    """

    # Parameters forwarded to the clamped-SwiGLU kernels.
    clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams()

    @staticmethod
    def create(activation_type, **kwargs):
        """Build an ActivationParams instance matching *activation_type*.

        Keyword arguments (e.g. ``limit``, ``alpha``) are forwarded to
        ``ClampedSwigluParams`` when the activation is a clamped variant;
        otherwise a default instance is returned.
        """
        clamped_types = {
            ("clamped_silu", "clamped_linear"),
            "clamped_silu",
            "clamped_linear",
        }
        if activation_type not in clamped_types:
            # Activations without parameters share the default configuration.
            return ActivationParams()
        return ActivationParams(ClampedSwigluParams(**kwargs))

    def __hash__(self):
        """Hash on the nested params so instances work as static jit arguments."""
        return hash((self.clamped_swiglu,))

    def to_ffi_lowering_dict(self):
        """Convert the activation parameters to nested-dict form for FFI lowering.

        Returns:
            dict: A dictionary representation consumable by the XLA FFI
            bindings for activation functions.
        """
        return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()}
def _convert_to_activation_function(fn_or_string, act_params: ActivationParams):
"""Convert a string to an activation function.""" """Convert a string to an activation function."""
if fn_or_string == "linear": if fn_or_string == "linear":
return lambda x: x return lambda x: x
if fn_or_string == "clamped_linear":
# This function is used for ClampedSwiGLU
# used in GPT OSS where the gates are not only clamped
# but also shifted by +1
limit = act_params.clamped_swiglu.limit
return lambda x: jnp.clip(x, min=-limit, max=limit) + 1
if fn_or_string == "quick_gelu": if fn_or_string == "quick_gelu":
return lambda x: jax.nn.sigmoid(1.702 * x) * x return lambda x: jax.nn.sigmoid(1.702 * x) * x
if fn_or_string == "squared_relu": if fn_or_string == "squared_relu":
return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)]) return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
if fn_or_string == "clamped_silu":
limit = act_params.clamped_swiglu.limit
alpha = act_params.clamped_swiglu.alpha
return lambda x: jax.nn.sigmoid(alpha * jnp.minimum(x, limit)) * jnp.minimum(x, limit)
if isinstance(fn_or_string, str): if isinstance(fn_or_string, str):
return getattr(jax.nn, fn_or_string) return getattr(jax.nn, fn_or_string)
if callable(fn_or_string): if callable(fn_or_string):
...@@ -84,7 +155,8 @@ class ActLuPrimitive(BasePrimitive): ...@@ -84,7 +155,8 @@ class ActLuPrimitive(BasePrimitive):
6, 6,
7, 7,
8, 8,
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer 9,
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer, act_params
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -100,11 +172,12 @@ class ActLuPrimitive(BasePrimitive): ...@@ -100,11 +172,12 @@ class ActLuPrimitive(BasePrimitive):
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer, is_outer,
act_params,
): ):
""" """
te_act_lu_p abstract te_act_lu_p abstract
""" """
del act_enum del act_enum, act_params
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
...@@ -150,6 +223,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -150,6 +223,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer, is_outer,
act_params,
): ):
""" """
te_gated_act_lu_p lowering rules te_gated_act_lu_p lowering rules
...@@ -158,9 +232,14 @@ class ActLuPrimitive(BasePrimitive): ...@@ -158,9 +232,14 @@ class ActLuPrimitive(BasePrimitive):
x_aval, scale_aval = ctx.avals_in x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
out = ffi.ffi_lowering(ActLuPrimitive.name)( out = ffi.ffi_lowering(ActLuPrimitive.name)(
ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x ctx,
x,
scale,
act_enum=act_enum,
scaling_mode=scaling_mode.value,
is_2x=is_2x,
act_params=act_params.to_ffi_lowering_dict(),
) )
return out return out
...@@ -175,6 +254,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -175,6 +254,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer, is_outer,
act_params,
): ):
""" """
to describe implementation to describe implementation
...@@ -193,6 +273,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -193,6 +273,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
is_outer=False, is_outer=False,
act_params=act_params,
) )
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
...@@ -221,6 +302,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -221,6 +302,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer, is_outer,
act_params,
): ):
""" """
to describe batch rules for vmap to describe batch rules for vmap
...@@ -242,6 +324,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -242,6 +324,7 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
act_params=act_params,
), ),
out_bdims, out_bdims,
) )
...@@ -255,6 +338,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -255,6 +338,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer, is_outer,
act_params,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -266,6 +350,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -266,6 +350,7 @@ class ActLuPrimitive(BasePrimitive):
scale_dtype, scale_dtype,
act_len, act_len,
is_outer, is_outer,
act_params,
) # Unused. ) # Unused.
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[1])
...@@ -318,6 +403,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -318,6 +403,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer, is_outer,
act_params,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -378,6 +464,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -378,6 +464,7 @@ class ActLuPrimitive(BasePrimitive):
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
is_outer=True, is_outer=True,
act_params=act_params,
) )
) )
...@@ -405,11 +492,12 @@ class ActLuPrimitive(BasePrimitive): ...@@ -405,11 +492,12 @@ class ActLuPrimitive(BasePrimitive):
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer, is_outer,
act_params,
mesh, mesh,
value_types, value_types,
result_types, result_types,
): ):
del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types, act_params
prefix = "ActLu_" prefix = "ActLu_"
input_shape = value_types[0].shape input_shape = value_types[0].shape
output_shape = input_shape[:-2] + input_shape[-1:] output_shape = input_shape[:-2] + input_shape[-1:]
...@@ -455,8 +543,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -455,8 +543,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
name = "te_dact_dbias_quantize_ffi" name = "te_dact_dbias_quantize_ffi"
multiple_results = True multiple_results = True
# out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer, act_params
impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10) impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -474,11 +562,12 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -474,11 +562,12 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
is_outer, is_outer,
act_params,
): ):
""" """
te_dact_dbias_quantize_p abstract te_dact_dbias_quantize_p abstract
""" """
del act_enum del act_enum, act_params
dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype) dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_dtype assert x_aval.dtype == dz_dtype
...@@ -575,6 +664,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -575,6 +664,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
is_outer, is_outer,
act_params,
): ):
""" """
te_dact_dbias_quantize_p lowering rules te_dact_dbias_quantize_p lowering rules
...@@ -593,6 +683,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -593,6 +683,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_2x=is_2x, is_2x=is_2x,
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=int(act_enum), act_enum=int(act_enum),
act_params=act_params.to_ffi_lowering_dict(),
) )
@staticmethod @staticmethod
...@@ -608,6 +699,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -608,6 +699,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
is_outer, is_outer,
act_params,
): ):
""" """
te_dact_dbias_quantize_p impl te_dact_dbias_quantize_p impl
...@@ -627,6 +719,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -627,6 +719,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum=act_enum, act_enum=act_enum,
act_len=act_len, act_len=act_len,
is_outer=False, is_outer=False,
act_params=act_params,
) )
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
...@@ -655,6 +748,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -655,6 +748,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
is_outer, is_outer,
act_params,
): ):
""" """
to describe batch rules for vmap to describe batch rules for vmap
...@@ -685,6 +779,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -685,6 +779,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_enum, act_enum=act_enum,
act_len=act_len, act_len=act_len,
act_params=act_params,
), ),
out_bdims, out_bdims,
) )
...@@ -699,11 +794,12 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -699,11 +794,12 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
is_outer, is_outer,
act_params,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
): ):
del out_dtype, result_infos, act_enum del out_dtype, result_infos, act_enum, act_params
del scale_dtype, act_len, is_outer del scale_dtype, act_len, is_outer
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
scale_spec = get_padded_spec(arg_infos[2]) scale_spec = get_padded_spec(arg_infos[2])
...@@ -774,6 +870,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -774,6 +870,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
is_outer, is_outer,
act_params,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -854,6 +951,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -854,6 +951,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum=act_enum, act_enum=act_enum,
act_len=act_len, act_len=act_len,
is_outer=True, is_outer=True,
act_params=act_params,
) )
) )
if is_dbias: if is_dbias:
...@@ -880,11 +978,13 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -880,11 +978,13 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum, act_enum,
act_len, act_len,
is_outer, is_outer,
act_params,
mesh, mesh,
value_types, value_types,
result_types, result_types,
): ):
del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types, act_params
prefix = "DActLuDBias_" prefix = "DActLuDBias_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2
...@@ -923,20 +1023,22 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): ...@@ -923,20 +1023,22 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
"""Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]: def _jax_act_lu(
inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None
) -> Union[NoScaleTensor, ScaledTensor]:
""" """
JAX native activation implementation JAX native activation implementation
""" """
act_params = act_params if act_params is not None else ActivationParams()
act_len = len(activation_type) act_len = len(activation_type)
assert inputs.shape[-2] == act_len, ( assert inputs.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape" "activation input should be replicated by act_len in the -2 axis, got input shape"
f" {inputs.shape} and act_len {act_len}" f" {inputs.shape} and act_len {act_len}"
) )
x = jnp.split(inputs, act_len, axis=-2) x = jnp.split(inputs, act_len, axis=-2)
acts = [] acts = []
for idx, act_fn in enumerate(activation_type): for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx]) x_i = _convert_to_activation_function(act_fn, act_params)(x[idx])
acts.append(x_i) acts.append(x_i)
x = reduce(operator.mul, acts) x = reduce(operator.mul, acts)
x = jnp.squeeze(x, axis=-2) x = jnp.squeeze(x, axis=-2)
...@@ -951,10 +1053,12 @@ def _jax_quantize_dact_dbias( ...@@ -951,10 +1053,12 @@ def _jax_quantize_dact_dbias(
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
is_dbias: bool = True, is_dbias: bool = True,
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None,
): ):
""" """
JAX implementation of dact_lu and dbias with optional quantization JAX implementation of dact_lu and dbias with optional quantization
""" """
act_params = act_params if act_params is not None else ActivationParams()
act_len = len(activation_type) act_len = len(activation_type)
assert x.shape[-2] == act_len, ( assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape" "activation input should be replicated by act_len in the -2 axis, got input shape"
...@@ -962,7 +1066,8 @@ def _jax_quantize_dact_dbias( ...@@ -962,7 +1066,8 @@ def _jax_quantize_dact_dbias(
) )
_, vjp_func = jax.vjp( _, vjp_func = jax.vjp(
partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) partial(_jax_act_lu, activation_type=activation_type, act_params=act_params),
x.astype(jnp.float32),
) )
# VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards.
dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None)
...@@ -985,6 +1090,7 @@ def act_lu( ...@@ -985,6 +1090,7 @@ def act_lu(
x: jnp.ndarray, x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None,
amax_scope: AmaxScope = AmaxScope.LOCAL, amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> Union[jnp.ndarray, ScaledTensor]: ) -> Union[jnp.ndarray, ScaledTensor]:
"""Activation with optional quantization. """Activation with optional quantization.
...@@ -1008,24 +1114,22 @@ def act_lu( ...@@ -1008,24 +1114,22 @@ def act_lu(
"activation input should be replicated by act_len in the -2 axis, got input shape" "activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}" f" {x.shape} and act_len {act_len}"
) )
act_params = act_params if act_params is not None else ActivationParams()
if not ActLuPrimitive.enabled(): if not ActLuPrimitive.enabled():
return _jax_act_lu(x, activation_type, quantizer) return _jax_act_lu(x, activation_type, quantizer, act_params)
# TE/common does not support colwise-only quantization yet # TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_act_lu(x, activation_type, quantizer) return _jax_act_lu(x, activation_type, quantizer, act_params)
# TE/common does not support 2x quantization for DelayedScaling yet # TE/common does not support 2x quantization for DelayedScaling yet
war_output = try_apply_delayed_scaling_2x_war( war_output = try_apply_delayed_scaling_2x_war(
f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer, act_params=act_params
) )
if war_output is not None: if war_output is not None:
return war_output return war_output
scale = jnp.empty((1,), jnp.float32) scale = jnp.empty((1,), jnp.float32)
output_shape = (*x.shape[:-2], x.shape[-1]) output_shape = (*x.shape[:-2], x.shape[-1])
if quantizer is None: if quantizer is None:
out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
x, x,
...@@ -1037,6 +1141,7 @@ def act_lu( ...@@ -1037,6 +1141,7 @@ def act_lu(
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
is_outer=True, is_outer=True,
act_params=act_params,
) )
out = out.reshape(output_shape) out = out.reshape(output_shape)
out = NoScaleTensor( out = NoScaleTensor(
...@@ -1051,6 +1156,7 @@ def act_lu( ...@@ -1051,6 +1156,7 @@ def act_lu(
x=x, x=x,
activation_type=activation_type, activation_type=activation_type,
quantizer=None, quantizer=None,
act_params=act_params,
) )
out, _ = _quantize_dbias_impl( out, _ = _quantize_dbias_impl(
out, out,
...@@ -1060,7 +1166,6 @@ def act_lu( ...@@ -1060,7 +1166,6 @@ def act_lu(
amax_scope=amax_scope, amax_scope=amax_scope,
) )
return out return out
if isinstance(quantizer, DelayedScaleQuantizer): if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale scale = quantizer.scale
...@@ -1080,6 +1185,7 @@ def act_lu( ...@@ -1080,6 +1185,7 @@ def act_lu(
is_2x=quantizer.is_2x2x(), is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
is_outer=True, is_outer=True,
act_params=act_params,
) )
quantizer.update(updated_amax) quantizer.update(updated_amax)
...@@ -1102,6 +1208,7 @@ def quantize_dact_dbias( ...@@ -1102,6 +1208,7 @@ def quantize_dact_dbias(
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
is_dbias: bool = True, is_dbias: bool = True,
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None,
) -> Tuple[ScaledTensor, jnp.ndarray]: ) -> Tuple[ScaledTensor, jnp.ndarray]:
"""Compute gradients of activation and bias with optional quantization. """Compute gradients of activation and bias with optional quantization.
...@@ -1118,7 +1225,7 @@ def quantize_dact_dbias( ...@@ -1118,7 +1225,7 @@ def quantize_dact_dbias(
- The gradient of the activation with respect to the input. - The gradient of the activation with respect to the input.
- The gradient of the activation with respect to the bias. - The gradient of the activation with respect to the bias.
""" """
act_params = act_params if act_params is not None else ActivationParams()
act_len = len(activation_type) act_len = len(activation_type)
assert x.shape[-2] == act_len, ( assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape" "activation input should be replicated by act_len in the -2 axis, got input shape"
...@@ -1131,8 +1238,7 @@ def quantize_dact_dbias( ...@@ -1131,8 +1238,7 @@ def quantize_dact_dbias(
if not PrimitiveClass.enabled() or ( if not PrimitiveClass.enabled() or (
quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE
): ):
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params)
if quantizer is None: if quantizer is None:
output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
dz, dz,
...@@ -1148,6 +1254,7 @@ def quantize_dact_dbias( ...@@ -1148,6 +1254,7 @@ def quantize_dact_dbias(
act_enum=act_type_id, act_enum=act_type_id,
act_len=act_len, act_len=act_len,
is_outer=True, is_outer=True,
act_params=act_params,
) )
output = output.astype(x.dtype) output = output.astype(x.dtype)
dbias = None dbias = None
...@@ -1163,7 +1270,11 @@ def quantize_dact_dbias( ...@@ -1163,7 +1270,11 @@ def quantize_dact_dbias(
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out = dact_lu( out = dact_lu(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None dz.astype(jnp.float32),
x.astype(jnp.float32),
activation_type,
quantizer=None,
act_params=act_params,
) )
return _quantize_dbias_impl( return _quantize_dbias_impl(
out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
...@@ -1180,6 +1291,7 @@ def quantize_dact_dbias( ...@@ -1180,6 +1291,7 @@ def quantize_dact_dbias(
is_dbias=is_dbias, is_dbias=is_dbias,
quantizer=quantizer, quantizer=quantizer,
flatten_axis=-2, flatten_axis=-2,
act_params=act_params,
) )
if war_output is not None: if war_output is not None:
return war_output return war_output
...@@ -1191,6 +1303,7 @@ def quantize_dact_dbias( ...@@ -1191,6 +1303,7 @@ def quantize_dact_dbias(
x=x, x=x,
activation_type=activation_type, activation_type=activation_type,
quantizer=None, quantizer=None,
act_params=act_params,
) )
out, dbias = _quantize_dbias_impl( out, dbias = _quantize_dbias_impl(
out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
...@@ -1203,7 +1316,10 @@ def quantize_dact_dbias( ...@@ -1203,7 +1316,10 @@ def quantize_dact_dbias(
# TE/common dact_dbias_quantize does not support gated act yet # TE/common dact_dbias_quantize does not support gated act yet
if is_dbias and is_gated: if is_dbias and is_gated:
dgated = dact_lu( dgated = dact_lu(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type dz.astype(jnp.float32),
x.astype(jnp.float32),
activation_type=activation_type,
act_params=act_params,
) )
out, dbias = _quantize_dbias_impl( out, dbias = _quantize_dbias_impl(
dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
...@@ -1229,6 +1345,7 @@ def quantize_dact_dbias( ...@@ -1229,6 +1345,7 @@ def quantize_dact_dbias(
act_enum=act_type_id, act_enum=act_type_id,
act_len=act_len, act_len=act_len,
is_outer=True, is_outer=True,
act_params=act_params,
) )
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
...@@ -1257,6 +1374,7 @@ def dact_lu( ...@@ -1257,6 +1374,7 @@ def dact_lu(
x: jnp.ndarray, x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None,
) -> Union[jnp.ndarray, ScaledTensor]: ) -> Union[jnp.ndarray, ScaledTensor]:
""" """
Backward pass for activation with optional quantization. Backward pass for activation with optional quantization.
...@@ -1270,11 +1388,13 @@ def dact_lu( ...@@ -1270,11 +1388,13 @@ def dact_lu(
Returns: Returns:
The gradient of the activation with respect to the input. The gradient of the activation with respect to the input.
""" """
act_params = act_params if act_params is not None else ActivationParams()
output, _ = quantize_dact_dbias( output, _ = quantize_dact_dbias(
dz=dz, dz=dz,
x=x, x=x,
activation_type=activation_type, activation_type=activation_type,
is_dbias=False, is_dbias=False,
quantizer=quantizer, quantizer=quantizer,
act_params=act_params,
) )
return output return output
...@@ -36,6 +36,15 @@ ...@@ -36,6 +36,15 @@
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
struct ClampedSwigluConfig {
float limit;
float alpha;
};
struct ActivationConfig {
ClampedSwigluConfig clamped_swiglu;
};
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }
// Activation // Activation
...@@ -137,6 +146,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); ...@@ -137,6 +146,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig,
::xla::ffi::StructMember<float>("limit"),
::xla::ffi::StructMember<float>("alpha"));
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
transformer_engine::jax::ActivationConfig,
::xla::ffi::StructMember<transformer_engine::jax::ClampedSwigluConfig>("clamped_swiglu"));
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op);
......
...@@ -18,7 +18,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -18,7 +18,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int) { bool is_2x_int, ActivationConfig act_params) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
...@@ -125,6 +128,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -125,6 +128,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
case NVTE_Activation_Type::SREGLU: case NVTE_Activation_Type::SREGLU:
nvte_sreglu(input_tensor.data(), output_tensor.data(), stream); nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
break; break;
case NVTE_Activation_Type::CLAMPED_SWIGLU:
nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha,
stream);
break;
default: default:
NVTE_ERROR("Unsupported ActivationEnum"); NVTE_ERROR("Unsupported ActivationEnum");
break; break;
...@@ -145,17 +152,19 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, ...@@ -145,17 +152,19 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"), .Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, Result_Type amax_buf, int64_t act_enum,
JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { JAXX_Scaling_Mode scaling_mode, bool is_2x_int,
ActivationConfig act_params) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf,
act_enum, scaling_mode, is_2x_int); act_enum, scaling_mode, is_2x_int, act_params);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
...@@ -170,7 +179,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, ...@@ -170,7 +179,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")); .Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
...@@ -240,7 +250,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -240,7 +250,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf, Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias) { int64_t act_enum, bool is_2x, bool is_dbias,
ActivationConfig act_params) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
...@@ -407,6 +421,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -407,6 +421,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
case NVTE_Activation_Type::SREGLU: case NVTE_Activation_Type::SREGLU:
nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break; break;
case NVTE_Activation_Type::CLAMPED_SWIGLU:
nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
swiglu_limit, swiglu_alpha, stream);
break;
default: default:
NVTE_ERROR("Unsupported ActivationEnum"); NVTE_ERROR("Unsupported ActivationEnum");
break; break;
...@@ -432,21 +450,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI ...@@ -432,21 +450,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<bool>("is_dbias"), .Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Error_Type DActLuDBiasQuantizeInitializeFFI(
Buffer_Type act_input_buf, Buffer_Type scale_buf, cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x,
Result_Type dbias_buf, Result_Type workspace_buf, bool is_dbias, ActivationConfig act_params) {
JAXX_Scaling_Mode scaling_mode, int64_t act_enum,
bool is_2x, bool is_dbias) {
return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
act_input_buf, scale_buf, output_buf, colwise_output_buf, act_input_buf, scale_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias); workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
...@@ -466,7 +483,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, ...@@ -466,7 +483,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<bool>("is_dbias")); .Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"));
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -143,6 +143,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -143,6 +143,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("QGEGLU", NVTE_Activation_Type::QGEGLU) .value("QGEGLU", NVTE_Activation_Type::QGEGLU)
.value("SRELU", NVTE_Activation_Type::SRELU) .value("SRELU", NVTE_Activation_Type::SRELU)
.value("SREGLU", NVTE_Activation_Type::SREGLU) .value("SREGLU", NVTE_Activation_Type::SREGLU)
.value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU)
.export_values(); .export_values();
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
......
...@@ -898,6 +898,10 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -898,6 +898,10 @@ class LayerNormMLP(TransformerEngineBase):
activations: Sequence[Union[str, Callable]], default = ('relu',) activations: Sequence[Union[str, Callable]], default = ('relu',)
The sequence of activation functions to apply after the first dense layer transformation. The sequence of activation functions to apply after the first dense layer transformation.
Each activation has its own transformation layer. Each activation has its own transformation layer.
activation_params: dict, default = None
The parameters needed(if any) by the activation functions specified in :attr:`activations`.
At the moment only ('clamped_silu', 'clamped_linear') which is clamped_swiglu used in GPT OSS
need additional parameters.
intermediate_dropout_rng_name: str, default = 'dropout' intermediate_dropout_rng_name: str, default = 'dropout'
The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks.
intermediate_dropout_rate: float, default = 0.1 intermediate_dropout_rate: float, default = 0.1
...@@ -956,6 +960,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -956,6 +960,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_axes_2: Tuple[str, ...] = ("embed",) bias_axes_2: Tuple[str, ...] = ("embed",)
return_layernorm_output: bool = True return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ("relu",) activations: Sequence[Union[str, Callable]] = ("relu",)
activation_params: dict = None
intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rng_name: str = "dropout"
intermediate_dropout_rate: float = 0.1 intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = () intermediate_hidden_dropout_dims: Sequence[int] = ()
...@@ -1023,6 +1028,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1023,6 +1028,7 @@ class LayerNormMLP(TransformerEngineBase):
("relu", "linear"), ("relu", "linear"),
("quick_gelu", "linear"), ("quick_gelu", "linear"),
("squared_relu", "linear"), ("squared_relu", "linear"),
("clamped_silu", "clamped_linear"),
] ]
act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)] act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
normalized_acts = [] normalized_acts = []
...@@ -1031,7 +1037,9 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1031,7 +1037,9 @@ class LayerNormMLP(TransformerEngineBase):
return False return False
normalized_acts.append(act.lower()) normalized_acts.append(act.lower())
normalized_acts = tuple( normalized_acts = tuple(
reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts reversed(normalized_acts)
if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear")
else normalized_acts
) )
is_act_implemented = normalized_acts in (gated_act_pool + act_pool) is_act_implemented = normalized_acts in (gated_act_pool + act_pool)
...@@ -1150,6 +1158,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1150,6 +1158,7 @@ class LayerNormMLP(TransformerEngineBase):
ffn1_ckpt_name=self.ffn1_ckpt_name, ffn1_ckpt_name=self.ffn1_ckpt_name,
ffn2_ckpt_name=self.ffn2_ckpt_name, ffn2_ckpt_name=self.ffn2_ckpt_name,
activation_type=normalized_acts, activation_type=normalized_acts,
activation_params=self.activation_params,
quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
) )
out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)
...@@ -1287,4 +1296,4 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1287,4 +1296,4 @@ class LayerNormMLP(TransformerEngineBase):
out = checkpoint_name(out, self.ffn2_ckpt_name) out = checkpoint_name(out, self.ffn2_ckpt_name)
assert out.dtype == input_dtype assert out.dtype == input_dtype
return out, ln_output # Output, layner_norm_output return out, ln_output # Output, layer_norm_output
...@@ -1632,6 +1632,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1632,6 +1632,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mlp_activations: Sequence[str], default = ('relu', ) mlp_activations: Sequence[str], default = ('relu', )
The sequence of activation functions to apply after the first linear transformation. The sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer. Each activation has its own transformation layer.
mlp_activation_params: dict = None
This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment
ClampedSwiglu is the only activation that requires parameters.
use_bias: bool, default = False use_bias: bool, default = False
Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
If set to False, the layer will not learn additive biases. If set to False, the layer will not learn additive biases.
...@@ -1752,6 +1755,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1752,6 +1755,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mha_kernel_init: Initializer = None mha_kernel_init: Initializer = None
mlp_kernel_init: Initializer = None mlp_kernel_init: Initializer = None
mlp_activations: Sequence[str] = ("relu",) mlp_activations: Sequence[str] = ("relu",)
mlp_activation_params: dict = None
use_bias: bool = False use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros bias_init: Initializer = nn.initializers.zeros
apply_residual_connection_post_layernorm: bool = False apply_residual_connection_post_layernorm: bool = False
...@@ -2046,6 +2050,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -2046,6 +2050,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
return_layernorm_output=self.apply_residual_connection_post_layernorm, return_layernorm_output=self.apply_residual_connection_post_layernorm,
intermediate_dim=self.mlp_hidden_size, intermediate_dim=self.mlp_hidden_size,
activations=self.mlp_activations, activations=self.mlp_activations,
activation_params=self.mlp_activation_params,
intermediate_dropout_rng_name=self.dropout_rng_name, intermediate_dropout_rng_name=self.dropout_rng_name,
intermediate_dropout_rate=self.intermediate_dropout, intermediate_dropout_rate=self.intermediate_dropout,
intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
......
...@@ -50,6 +50,7 @@ def layernorm_mlp( ...@@ -50,6 +50,7 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1", ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2", ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
activation_params: dict = None,
collective_op_sets: Tuple[tex.CollectiveOpSet] = ( collective_op_sets: Tuple[tex.CollectiveOpSet] = (
tex.noop_collective_op_set, tex.noop_collective_op_set,
tex.noop_collective_op_set, tex.noop_collective_op_set,
...@@ -138,13 +139,14 @@ def layernorm_mlp( ...@@ -138,13 +139,14 @@ def layernorm_mlp(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
activation_params,
collective_op_sets, collective_op_sets,
quantizer_sets, quantizer_sets,
) )
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19)) @partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20))
def _layernorm_mlp( def _layernorm_mlp(
x: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
...@@ -165,6 +167,7 @@ def _layernorm_mlp( ...@@ -165,6 +167,7 @@ def _layernorm_mlp(
ffn1_ckpt_name: str, ffn1_ckpt_name: str,
ffn2_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
activation_params: dict,
collective_op_sets: Tuple[tex.CollectiveOpSet], collective_op_sets: Tuple[tex.CollectiveOpSet],
quantizer_sets, quantizer_sets,
): ):
...@@ -220,6 +223,7 @@ def _layernorm_mlp( ...@@ -220,6 +223,7 @@ def _layernorm_mlp(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
activation_params,
collective_op_sets, collective_op_sets,
quantizer_sets, quantizer_sets,
) )
...@@ -246,6 +250,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -246,6 +250,7 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
activation_params,
collective_op_sets, collective_op_sets,
quantizer_sets, quantizer_sets,
): ):
...@@ -335,6 +340,11 @@ def _layernorm_mlp_fwd_rule( ...@@ -335,6 +340,11 @@ def _layernorm_mlp_fwd_rule(
dot_1_output, dot_1_output,
activation_type, activation_type,
quantizer=ffn2_quantizer_set.x, quantizer=ffn2_quantizer_set.x,
act_params=(
tex.activation.ActivationParams.create(activation_type, **activation_params)
if activation_params
else None
),
) )
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
...@@ -402,6 +412,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -402,6 +412,7 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
activation_params,
collective_op_sets, collective_op_sets,
ctx, ctx,
grad, grad,
...@@ -497,6 +508,11 @@ def _layernorm_mlp_bwd_rule( ...@@ -497,6 +508,11 @@ def _layernorm_mlp_bwd_rule(
activation_type=activation_type, activation_type=activation_type,
is_dbias=use_bias_1, is_dbias=use_bias_1,
quantizer=ffn2_quantizer_set.dgrad, quantizer=ffn2_quantizer_set.dgrad,
act_params=(
tex.activation.ActivationParams.create(activation_type, **activation_params)
if activation_params
else None
),
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment