Unverified Commit b840898b authored by vthumbe1503, committed by GitHub

[JAX] Clamped Swiglu Integration (#2194)


Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* JAX integration for clamped SwiGLU. This is a continuation of the PR that added Clamped SwiGLU (used in GPT OSS) support in TE along with the PyTorch integration. This PR hooks up the clamped swiglu and dswiglu nvte APIs to TE JAX.
parent e30c36a3
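For context, the new user-facing path looks roughly like the minimal sketch below, mirroring the test changes further down. The import paths, the input shape, and the limit/alpha values are illustrative assumptions, not part of this commit.

import jax
import jax.numpy as jnp
import transformer_engine.jax.cpp_extensions as tex       # assumed import path
from transformer_engine.jax.activation import activation  # assumed import path

activation_type = ("clamped_silu", "clamped_linear")

# Pack the ClampedSwiGLU parameters; the values mirror the ones used in the tests below.
act_params = tex.activation.ActivationParams.create(
    activation_type=activation_type, limit=0.75, alpha=1.702
)

# Shape follows the (..., num_activations, hidden) layout used in the tests; illustrative only.
x = jax.random.normal(jax.random.PRNGKey(0), (16, 2, 128), jnp.bfloat16)

out = activation(x, activation_type=activation_type, quantizer=None, act_params=act_params)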
......@@ -170,6 +170,7 @@ ALL_ACTIVATION_TYPES = [
("quick_gelu", "linear"),
("squared_relu",),
("squared_relu", "linear"),
("clamped_silu", "clamped_linear"),
]
ACTIVATION_TYPES = {
......@@ -182,17 +183,21 @@ ACTIVATION_TYPES = {
class TestActivation:
def ref_act(self, x, activation_type):
return _jax_act_lu(x, activation_type).data
def ref_act(self, x, activation_type, act_params):
return _jax_act_lu(x, activation_type, act_params=act_params).data
def value_n_grad_ref_func(self, x, activation_type):
def value_n_grad_ref_func(self, x, activation_type, act_params):
jitted_reference = jit(
value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,))
value_and_grad(
lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)
)
)
return jitted_reference(x)
def primitive_func(self, inputs, activation_type, quantizer):
out = activation(inputs, activation_type=activation_type, quantizer=quantizer)
def primitive_func(self, inputs, activation_type, quantizer, act_params):
out = activation(
inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params
)
return jnp.mean(out)
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
......@@ -209,12 +214,20 @@ class TestActivation:
x = jnp.repeat(x, len(activation_type), axis=-2)
value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
......@@ -234,7 +247,8 @@ class TestActivation:
self.activation_type = activation_type
value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
value_and_grad(self.primitive_func, (0,)),
static_argnums=(1, 3),
)
quantizer = QuantizerFactory.create(
......@@ -242,9 +256,21 @@ class TestActivation:
q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE,
)
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(
x, activation_type, quantizer, act_params
)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
assert_allclose(prim_out, ref_out, dtype=output_type)
assert_allclose(prim_grad, ref_grad, dtype=output_type)
......@@ -273,10 +299,18 @@ class TestActivation:
q_dtype=output_type,
q_layout=q_layout,
)
te_output = tex.act_lu(x, activation_type, te_quantizer)
jax_output = _jax_act_lu(x, activation_type, jax_quantizer)
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
te_output = tex.act_lu(x, activation_type, te_quantizer, act_params)
jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params)
assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
......@@ -296,10 +330,18 @@ class TestActivation:
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
)
output = tex.act_lu(x, activation_type, quantizer)
ref_out = self.ref_act(x, activation_type)
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
output = tex.act_lu(x, activation_type, quantizer, act_params)
ref_out = self.ref_act(x, activation_type, act_params)
assert_dequantized_scaled_tensor(output, ref_out)
......@@ -734,6 +776,7 @@ class TestFusedQuantize:
def _test_quantize_dact_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
......@@ -785,7 +828,7 @@ class TestFusedQuantize:
(in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling())
# Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16, which reduces precision compared to the JAX implementation, which implicitly promotes intermediate results to float32 when JIT'd. Currently this only produces a tolerance issue when using squared_relu.
or (
activation_type == ("squared_relu",)
activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")}
and in_dtype == jnp.bfloat16
and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
)
......
......@@ -39,6 +39,7 @@ enum class NVTE_Activation_Type {
QGEGLU,
SRELU,
SREGLU,
CLAMPED_SWIGLU
};
/*! \brief Computes the GeLU activation of the input.
......
......@@ -924,7 +924,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
checkCuDriverContext(stream);
......@@ -1006,7 +1006,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
checkCuDriverContext(stream);
......@@ -1138,7 +1138,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
case ScalingType::COLWISE:
......@@ -1155,7 +1154,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
......@@ -1180,7 +1178,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) {
void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(input.flat_last_dim() % 2 == 0,
......@@ -1213,7 +1211,7 @@ void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t st
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP &p,
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p,
cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input");
......@@ -1252,7 +1250,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamO
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
constexpr bool allow_empty = false;
CheckInputTensor(gated_input, "gated_input");
......@@ -1318,7 +1316,7 @@ namespace detail {
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output,
ParamOP &p, cudaStream_t stream) {
ParamOP p, cudaStream_t stream) {
using namespace gated_kernels;
Tensor grad_empty_tensor;
const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
......
......@@ -11,7 +11,6 @@ from functools import partial
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize.tensor import NoScaleTensor
......@@ -22,6 +21,7 @@ def activation(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
act_params: Optional[tex.activation.ActivationParams] = None,
) -> jnp.ndarray:
"""Apply activation functions to input tensor with optional quantization.
......@@ -32,17 +32,19 @@ def activation(
x: Input tensor to apply activations to
activation_type: Sequence of activation functions
quantizer: Optional quantizer for quantizing the output
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
Returns:
Activated output tensor
"""
assert x.shape[-1] % len(activation_type) == 0
output = _activation(x, activation_type, quantizer)
output = _activation(x, activation_type, quantizer, act_params)
return output
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation(x, activation_type, quantizer):
@partial(jax.custom_vjp, nondiff_argnums=(1, 3))
def _activation(x, activation_type, quantizer, act_params):
"""Internal implementation of activation with custom VJP.
This function implements the core activation logic with support for
......@@ -52,36 +54,42 @@ def _activation(x, activation_type, quantizer):
x: Input tensor
activation_type: Sequence of activation functions
quantizer: Optional quantizer
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
Returns:
Activated tensor
"""
_output, _ = _activation_fwd_rule(x, activation_type, quantizer)
_output, _ = _activation_fwd_rule(x, activation_type, quantizer, act_params)
return _output
def _activation_fwd_rule(x, activation_type, quantizer):
def _activation_fwd_rule(x, activation_type, quantizer, act_params):
"""Forward pass rule for activation function.
Args:
x: Input tensor
activation_type: Sequence of activation functions
quantizer: Optional quantizer
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
Returns:
Tuple of (output, context) for backward pass
"""
fwd_output = tex.act_lu(x, activation_type, quantizer)
fwd_output = tex.act_lu(x, activation_type, quantizer, act_params)
# This is a no-op for higher-precision tensors
fwd_output = fwd_output.dequantize()
return fwd_output, (x, quantizer)
def _activation_bwd_rule(activation_type, ctx, g):
def _activation_bwd_rule(activation_type, act_params, ctx, g):
"""Backward pass rule for activation function.
Args:
activation_type: Sequence of activation functions
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
ctx: Context from forward pass
g: Gradient from upstream
......@@ -90,7 +98,7 @@ def _activation_bwd_rule(activation_type, ctx, g):
"""
(x, _) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
dx = tex.dact_lu(g, x, activation_type, act_params=act_params)
# No quantization is used in this VJP backward, so the output should
# always be a NoScaleTensor
assert isinstance(dx, NoScaleTensor)
......
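For readers less familiar with jax.custom_vjp: adding act_params to nondiff_argnums=(1, 3) marks it as a static, non-differentiable argument, so the backward rule receives it ahead of the residuals, which is why _activation_bwd_rule now takes (activation_type, act_params, ctx, g). A standalone toy sketch of that wiring (not TE code) follows.

from functools import partial
import jax
import jax.numpy as jnp

# Toy analogue of _activation: arguments 1 and 2 are static/non-differentiable,
# so the bwd rule receives them before the saved residuals and the cotangent.
@partial(jax.custom_vjp, nondiff_argnums=(1, 2))
def clamp_scale(x, limit, alpha):
    out, _ = clamp_scale_fwd(x, limit, alpha)
    return out

def clamp_scale_fwd(x, limit, alpha):
    # Returns (output, residuals saved for the backward pass)
    return alpha * jnp.minimum(x, limit), (x,)

def clamp_scale_bwd(limit, alpha, ctx, g):
    # Non-diff args arrive first, then the residuals, then the upstream gradient.
    (x,) = ctx
    return (jnp.where(x < limit, alpha, 0.0) * g,)

clamp_scale.defvjp(clamp_scale_fwd, clamp_scale_bwd)

grads = jax.grad(lambda x: clamp_scale(x, 0.75, 1.702).sum())(jnp.linspace(-1.0, 1.0, 5))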
......@@ -36,6 +36,15 @@
namespace transformer_engine {
namespace jax {
struct ClampedSwigluConfig {
float limit;
float alpha;
};
struct ActivationConfig {
ClampedSwigluConfig clamped_swiglu;
};
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }
// Activation
......@@ -137,6 +146,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);
} // namespace jax
} // namespace transformer_engine
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig,
::xla::ffi::StructMember<float>("limit"),
::xla::ffi::StructMember<float>("alpha"));
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
transformer_engine::jax::ActivationConfig,
::xla::ffi::StructMember<transformer_engine::jax::ClampedSwigluConfig>("clamped_swiglu"));
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op);
......
......@@ -18,7 +18,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int) {
bool is_2x_int, ActivationConfig act_params) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
......@@ -125,6 +128,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
case NVTE_Activation_Type::SREGLU:
nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::CLAMPED_SWIGLU:
nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha,
stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
......@@ -145,17 +152,19 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"),
.Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"),
FFI_CudaGraph_Traits);
Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum,
JAXX_Scaling_Mode scaling_mode, bool is_2x_int) {
JAXX_Scaling_Mode scaling_mode, bool is_2x_int,
ActivationConfig act_params) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf,
act_enum, scaling_mode, is_2x_int);
act_enum, scaling_mode, is_2x_int, act_params);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
......@@ -170,7 +179,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
.Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"));
.Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
......@@ -240,7 +250,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias) {
int64_t act_enum, bool is_2x, bool is_dbias,
ActivationConfig act_params) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
......@@ -407,6 +421,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
case NVTE_Activation_Type::SREGLU:
nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::CLAMPED_SWIGLU:
nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
swiglu_limit, swiglu_alpha, stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
......@@ -432,21 +450,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<bool>("is_dbias"),
.Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"),
FFI_CudaGraph_Traits);
Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t act_enum,
bool is_2x, bool is_dbias) {
Error_Type DActLuDBiasQuantizeInitializeFFI(
cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x,
bool is_dbias, ActivationConfig act_params) {
return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
act_input_buf, scale_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias);
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
......@@ -466,7 +483,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<bool>("is_dbias"));
.Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"));
} // namespace jax
} // namespace transformer_engine
......@@ -143,6 +143,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("QGEGLU", NVTE_Activation_Type::QGEGLU)
.value("SRELU", NVTE_Activation_Type::SRELU)
.value("SREGLU", NVTE_Activation_Type::SREGLU)
.value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU)
.export_values();
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
......
......@@ -898,6 +898,10 @@ class LayerNormMLP(TransformerEngineBase):
activations: Sequence[Union[str, Callable]], default = ('relu',)
The sequence of activation functions to apply after the first dense layer transformation.
Each activation has its own transformation layer.
activation_params: dict, default = None
The parameters needed (if any) by the activation functions specified in :attr:`activations`.
At the moment only ('clamped_silu', 'clamped_linear'), i.e. the clamped SwiGLU used in GPT OSS,
needs additional parameters.
intermediate_dropout_rng_name: str, default = 'dropout'
The key in the given RNGs via flax.linen.Module.apply that is used for generating Dropout masks.
intermediate_dropout_rate: float, default = 0.1
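A sketch of how the new field might be set on the Flax module; the import path and parameter values are assumptions for illustration, and all other constructor fields keep their defaults.

from transformer_engine.jax.flax import LayerNormMLP  # assumed import path

mlp = LayerNormMLP(
    intermediate_dim=4096,
    activations=("clamped_silu", "clamped_linear"),
    # Forwarded to layernorm_mlp, which builds ActivationParams.create(...) from this dict.
    activation_params={"limit": 0.75, "alpha": 1.702},  # illustrative values
)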
......@@ -956,6 +960,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_axes_2: Tuple[str, ...] = ("embed",)
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ("relu",)
activation_params: dict = None
intermediate_dropout_rng_name: str = "dropout"
intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = ()
......@@ -1023,6 +1028,7 @@ class LayerNormMLP(TransformerEngineBase):
("relu", "linear"),
("quick_gelu", "linear"),
("squared_relu", "linear"),
("clamped_silu", "clamped_linear"),
]
act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
normalized_acts = []
......@@ -1031,7 +1037,9 @@ class LayerNormMLP(TransformerEngineBase):
return False
normalized_acts.append(act.lower())
normalized_acts = tuple(
reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts
reversed(normalized_acts)
if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear")
else normalized_acts
)
is_act_implemented = normalized_acts in (gated_act_pool + act_pool)
......@@ -1150,6 +1158,7 @@ class LayerNormMLP(TransformerEngineBase):
ffn1_ckpt_name=self.ffn1_ckpt_name,
ffn2_ckpt_name=self.ffn2_ckpt_name,
activation_type=normalized_acts,
activation_params=self.activation_params,
quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
)
out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)
......@@ -1287,4 +1296,4 @@ class LayerNormMLP(TransformerEngineBase):
out = checkpoint_name(out, self.ffn2_ckpt_name)
assert out.dtype == input_dtype
return out, ln_output # Output, layner_norm_output
return out, ln_output # Output, layer_norm_output
......@@ -1632,6 +1632,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mlp_activations: Sequence[str], default = ('relu', )
The sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer.
mlp_activation_params: dict, default = None
This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment
ClampedSwiGLU is the only activation that requires parameters.
use_bias: bool, default = False
Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
If set to False, the layer will not learn additive biases.
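The same knob one level up, again as a hedged sketch with an assumed import path and illustrative values:

from transformer_engine.jax.flax import TransformerLayer  # assumed import path

layer = TransformerLayer(
    mlp_activations=("clamped_silu", "clamped_linear"),
    # Passed through to LayerNormMLP.activation_params.
    mlp_activation_params={"limit": 0.75, "alpha": 1.702},  # illustrative values
)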
......@@ -1752,6 +1755,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mha_kernel_init: Initializer = None
mlp_kernel_init: Initializer = None
mlp_activations: Sequence[str] = ("relu",)
mlp_activation_params: dict = None
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
apply_residual_connection_post_layernorm: bool = False
......@@ -2046,6 +2050,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
return_layernorm_output=self.apply_residual_connection_post_layernorm,
intermediate_dim=self.mlp_hidden_size,
activations=self.mlp_activations,
activation_params=self.mlp_activation_params,
intermediate_dropout_rng_name=self.dropout_rng_name,
intermediate_dropout_rate=self.intermediate_dropout,
intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
......
......@@ -50,6 +50,7 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
activation_params: dict = None,
collective_op_sets: Tuple[tex.CollectiveOpSet] = (
tex.noop_collective_op_set,
tex.noop_collective_op_set,
......@@ -138,13 +139,14 @@ def layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
activation_params,
collective_op_sets,
quantizer_sets,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20))
def _layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
......@@ -165,6 +167,7 @@ def _layernorm_mlp(
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
activation_params: dict,
collective_op_sets: Tuple[tex.CollectiveOpSet],
quantizer_sets,
):
......@@ -220,6 +223,7 @@ def _layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
activation_params,
collective_op_sets,
quantizer_sets,
)
......@@ -246,6 +250,7 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
activation_params,
collective_op_sets,
quantizer_sets,
):
......@@ -335,6 +340,11 @@ def _layernorm_mlp_fwd_rule(
dot_1_output,
activation_type,
quantizer=ffn2_quantizer_set.x,
act_params=(
tex.activation.ActivationParams.create(activation_type, **activation_params)
if activation_params
else None
),
)
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
......@@ -402,6 +412,7 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
activation_params,
collective_op_sets,
ctx,
grad,
......@@ -497,6 +508,11 @@ def _layernorm_mlp_bwd_rule(
activation_type=activation_type,
is_dbias=use_bias_1,
quantizer=ffn2_quantizer_set.dgrad,
act_params=(
tex.activation.ActivationParams.create(activation_type, **activation_params)
if activation_params
else None
),
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......