Unverified Commit b840898b authored by vthumbe1503's avatar vthumbe1503 Committed by GitHub
Browse files

[JAX] Clamped Swiglu Integration (#2194)


Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>
*Jax integration for clamped swiglu. This is the continuation of PR which added Clamped Swiglu(used in GPT OSS) support in TE along with Pytorch integration. This PR hooks up the clamped swiglu and dswiglu's nvte APIs to TE Jax.
parent e30c36a3
...@@ -170,6 +170,7 @@ ALL_ACTIVATION_TYPES = [ ...@@ -170,6 +170,7 @@ ALL_ACTIVATION_TYPES = [
("quick_gelu", "linear"), ("quick_gelu", "linear"),
("squared_relu",), ("squared_relu",),
("squared_relu", "linear"), ("squared_relu", "linear"),
("clamped_silu", "clamped_linear"),
] ]
ACTIVATION_TYPES = { ACTIVATION_TYPES = {
...@@ -182,17 +183,21 @@ ACTIVATION_TYPES = { ...@@ -182,17 +183,21 @@ ACTIVATION_TYPES = {
class TestActivation: class TestActivation:
def ref_act(self, x, activation_type): def ref_act(self, x, activation_type, act_params):
return _jax_act_lu(x, activation_type).data return _jax_act_lu(x, activation_type, act_params=act_params).data
def value_n_grad_ref_func(self, x, activation_type): def value_n_grad_ref_func(self, x, activation_type, act_params):
jitted_reference = jit( jitted_reference = jit(
value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,)) value_and_grad(
lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)
)
) )
return jitted_reference(x) return jitted_reference(x)
def primitive_func(self, inputs, activation_type, quantizer): def primitive_func(self, inputs, activation_type, quantizer, act_params):
out = activation(inputs, activation_type=activation_type, quantizer=quantizer) out = activation(
inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params
)
return jnp.mean(out) return jnp.mean(out)
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
...@@ -209,12 +214,20 @@ class TestActivation: ...@@ -209,12 +214,20 @@ class TestActivation:
x = jnp.repeat(x, len(activation_type), axis=-2) x = jnp.repeat(x, len(activation_type), axis=-2)
value_n_grad_primitive_func = jit( value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
) )
act_args = (
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None) {"limit": 0.75, "alpha": 1.702}
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
...@@ -234,7 +247,8 @@ class TestActivation: ...@@ -234,7 +247,8 @@ class TestActivation:
self.activation_type = activation_type self.activation_type = activation_type
value_n_grad_primitive_func = jit( value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) value_and_grad(self.primitive_func, (0,)),
static_argnums=(1, 3),
) )
quantizer = QuantizerFactory.create( quantizer = QuantizerFactory.create(
...@@ -242,9 +256,21 @@ class TestActivation: ...@@ -242,9 +256,21 @@ class TestActivation:
q_dtype=output_type, q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE, q_layout=QuantizeLayout.ROWWISE,
) )
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer) act_params = (
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(
x, activation_type, quantizer, act_params
)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_out, ref_out, dtype=output_type)
assert_allclose(prim_grad, ref_grad, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type)
...@@ -273,10 +299,18 @@ class TestActivation: ...@@ -273,10 +299,18 @@ class TestActivation:
q_dtype=output_type, q_dtype=output_type,
q_layout=q_layout, q_layout=q_layout,
) )
act_args = (
te_output = tex.act_lu(x, activation_type, te_quantizer) {"limit": 0.75, "alpha": 1.702}
jax_output = _jax_act_lu(x, activation_type, jax_quantizer) if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
te_output = tex.act_lu(x, activation_type, te_quantizer, act_params)
jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params)
assert_bitwise_scaled_tensors(te_output, jax_output) assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
...@@ -296,10 +330,18 @@ class TestActivation: ...@@ -296,10 +330,18 @@ class TestActivation:
quantizer = QuantizerFactory.create( quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
) )
act_args = (
output = tex.act_lu(x, activation_type, quantizer) {"limit": 0.75, "alpha": 1.702}
ref_out = self.ref_act(x, activation_type) if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
output = tex.act_lu(x, activation_type, quantizer, act_params)
ref_out = self.ref_act(x, activation_type, act_params)
assert_dequantized_scaled_tensor(output, ref_out) assert_dequantized_scaled_tensor(output, ref_out)
...@@ -734,6 +776,7 @@ class TestFusedQuantize: ...@@ -734,6 +776,7 @@ class TestFusedQuantize:
def _test_quantize_dact_dbias( def _test_quantize_dact_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
): ):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2) subkeys = jax.random.split(key, 2)
x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1) x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
...@@ -785,7 +828,7 @@ class TestFusedQuantize: ...@@ -785,7 +828,7 @@ class TestFusedQuantize:
(in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()) (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling())
# Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently. # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently.
or ( or (
activation_type == ("squared_relu",) activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")}
and in_dtype == jnp.bfloat16 and in_dtype == jnp.bfloat16
and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
) )
......
...@@ -39,6 +39,7 @@ enum class NVTE_Activation_Type { ...@@ -39,6 +39,7 @@ enum class NVTE_Activation_Type {
QGEGLU, QGEGLU,
SRELU, SRELU,
SREGLU, SREGLU,
CLAMPED_SWIGLU
}; };
/*! \brief Computes the GeLU activation of the input. /*! \brief Computes the GeLU activation of the input.
......
...@@ -924,7 +924,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -924,7 +924,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) { cudaStream_t stream) {
checkCuDriverContext(stream); checkCuDriverContext(stream);
...@@ -1006,7 +1006,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu ...@@ -1006,7 +1006,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) { cudaStream_t stream) {
checkCuDriverContext(stream); checkCuDriverContext(stream);
...@@ -1138,7 +1138,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -1138,7 +1138,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_colwise, tensor_map_output_gate_colwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p); scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
case ScalingType::COLWISE: case ScalingType::COLWISE:
...@@ -1155,7 +1154,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -1155,7 +1154,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p); scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
...@@ -1180,7 +1178,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -1180,7 +1178,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
} }
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)> template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) { void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input"); CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output"); CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(input.flat_last_dim() % 2 == 0, NVTE_CHECK(input.flat_last_dim() % 2 == 0,
...@@ -1213,7 +1211,7 @@ void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t st ...@@ -1213,7 +1211,7 @@ void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t st
template <typename ParamOP, float (*ActOP)(float, const ParamOP &), template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP &p, void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p,
cudaStream_t stream) { cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad"); CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input"); CheckInputTensor(input, "dgated_act_input");
...@@ -1252,7 +1250,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamO ...@@ -1252,7 +1250,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamO
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) { cudaStream_t stream) {
constexpr bool allow_empty = false; constexpr bool allow_empty = false;
CheckInputTensor(gated_input, "gated_input"); CheckInputTensor(gated_input, "gated_input");
...@@ -1318,7 +1316,7 @@ namespace detail { ...@@ -1318,7 +1316,7 @@ namespace detail {
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output,
ParamOP &p, cudaStream_t stream) { ParamOP p, cudaStream_t stream) {
using namespace gated_kernels; using namespace gated_kernels;
Tensor grad_empty_tensor; Tensor grad_empty_tensor;
const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
......
...@@ -11,7 +11,6 @@ from functools import partial ...@@ -11,7 +11,6 @@ from functools import partial
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .quantize.tensor import NoScaleTensor from .quantize.tensor import NoScaleTensor
...@@ -22,6 +21,7 @@ def activation( ...@@ -22,6 +21,7 @@ def activation(
x: jnp.ndarray, x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
act_params: Optional[tex.activation.ActivationParams] = None,
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Apply activation functions to input tensor with optional quantization. """Apply activation functions to input tensor with optional quantization.
...@@ -32,17 +32,19 @@ def activation( ...@@ -32,17 +32,19 @@ def activation(
x: Input tensor to apply activations to x: Input tensor to apply activations to
activation_type: Sequence of activation functions activation_type: Sequence of activation functions
quantizer: Optional quantizer for quantizing the output quantizer: Optional quantizer for quantizing the output
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
Returns: Returns:
Activated output tensor Activated output tensor
""" """
assert x.shape[-1] % len(activation_type) == 0 assert x.shape[-1] % len(activation_type) == 0
output = _activation(x, activation_type, quantizer) output = _activation(x, activation_type, quantizer, act_params)
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(1,)) @partial(jax.custom_vjp, nondiff_argnums=(1, 3))
def _activation(x, activation_type, quantizer): def _activation(x, activation_type, quantizer, act_params):
"""Internal implementation of activation with custom VJP. """Internal implementation of activation with custom VJP.
This function implements the core activation logic with support for This function implements the core activation logic with support for
...@@ -52,36 +54,42 @@ def _activation(x, activation_type, quantizer): ...@@ -52,36 +54,42 @@ def _activation(x, activation_type, quantizer):
x: Input tensor x: Input tensor
activation_type: Sequence of activation functions activation_type: Sequence of activation functions
quantizer: Optional quantizer quantizer: Optional quantizer
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
Returns: Returns:
Activated tensor Activated tensor
""" """
_output, _ = _activation_fwd_rule(x, activation_type, quantizer) _output, _ = _activation_fwd_rule(x, activation_type, quantizer, act_params)
return _output return _output
def _activation_fwd_rule(x, activation_type, quantizer): def _activation_fwd_rule(x, activation_type, quantizer, act_params):
"""Forward pass rule for activation function. """Forward pass rule for activation function.
Args: Args:
x: Input tensor x: Input tensor
activation_type: Sequence of activation functions activation_type: Sequence of activation functions
quantizer: Optional quantizer quantizer: Optional quantizer
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
Returns: Returns:
Tuple of (output, context) for backward pass Tuple of (output, context) for backward pass
""" """
fwd_output = tex.act_lu(x, activation_type, quantizer) fwd_output = tex.act_lu(x, activation_type, quantizer, act_params)
# This is a no-op for higher-precision tensors # This is a no-op for higher-precision tensors
fwd_output = fwd_output.dequantize() fwd_output = fwd_output.dequantize()
return fwd_output, (x, quantizer) return fwd_output, (x, quantizer)
def _activation_bwd_rule(activation_type, ctx, g): def _activation_bwd_rule(activation_type, act_params, ctx, g):
"""Backward pass rule for activation function. """Backward pass rule for activation function.
Args: Args:
activation_type: Sequence of activation functions activation_type: Sequence of activation functions
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
ctx: Context from forward pass ctx: Context from forward pass
g: Gradient from upstream g: Gradient from upstream
...@@ -90,7 +98,7 @@ def _activation_bwd_rule(activation_type, ctx, g): ...@@ -90,7 +98,7 @@ def _activation_bwd_rule(activation_type, ctx, g):
""" """
(x, _) = ctx (x, _) = ctx
assert x.dtype == g.dtype assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type) dx = tex.dact_lu(g, x, activation_type, act_params=act_params)
# No quantization is used in this VJP backward, so the output should # No quantization is used in this VJP backward, so the output should
# always be a NoScaleTensor # always be a NoScaleTensor
assert isinstance(dx, NoScaleTensor) assert isinstance(dx, NoScaleTensor)
......
...@@ -36,6 +36,15 @@ ...@@ -36,6 +36,15 @@
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
struct ClampedSwigluConfig {
float limit;
float alpha;
};
struct ActivationConfig {
ClampedSwigluConfig clamped_swiglu;
};
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }
// Activation // Activation
...@@ -137,6 +146,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); ...@@ -137,6 +146,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig,
::xla::ffi::StructMember<float>("limit"),
::xla::ffi::StructMember<float>("alpha"));
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
transformer_engine::jax::ActivationConfig,
::xla::ffi::StructMember<transformer_engine::jax::ClampedSwigluConfig>("clamped_swiglu"));
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op);
......
...@@ -18,7 +18,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -18,7 +18,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int) { bool is_2x_int, ActivationConfig act_params) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
...@@ -125,6 +128,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -125,6 +128,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
case NVTE_Activation_Type::SREGLU: case NVTE_Activation_Type::SREGLU:
nvte_sreglu(input_tensor.data(), output_tensor.data(), stream); nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
break; break;
case NVTE_Activation_Type::CLAMPED_SWIGLU:
nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha,
stream);
break;
default: default:
NVTE_ERROR("Unsupported ActivationEnum"); NVTE_ERROR("Unsupported ActivationEnum");
break; break;
...@@ -145,17 +152,19 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, ...@@ -145,17 +152,19 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"), .Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, Result_Type amax_buf, int64_t act_enum,
JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { JAXX_Scaling_Mode scaling_mode, bool is_2x_int,
ActivationConfig act_params) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf,
act_enum, scaling_mode, is_2x_int); act_enum, scaling_mode, is_2x_int, act_params);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
...@@ -170,7 +179,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, ...@@ -170,7 +179,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")); .Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
...@@ -240,7 +250,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -240,7 +250,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf, Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias) { int64_t act_enum, bool is_2x, bool is_dbias,
ActivationConfig act_params) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
...@@ -407,6 +421,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -407,6 +421,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
case NVTE_Activation_Type::SREGLU: case NVTE_Activation_Type::SREGLU:
nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break; break;
case NVTE_Activation_Type::CLAMPED_SWIGLU:
nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
swiglu_limit, swiglu_alpha, stream);
break;
default: default:
NVTE_ERROR("Unsupported ActivationEnum"); NVTE_ERROR("Unsupported ActivationEnum");
break; break;
...@@ -432,21 +450,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI ...@@ -432,21 +450,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<bool>("is_dbias"), .Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Error_Type DActLuDBiasQuantizeInitializeFFI(
Buffer_Type act_input_buf, Buffer_Type scale_buf, cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x,
Result_Type dbias_buf, Result_Type workspace_buf, bool is_dbias, ActivationConfig act_params) {
JAXX_Scaling_Mode scaling_mode, int64_t act_enum,
bool is_2x, bool is_dbias) {
return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
act_input_buf, scale_buf, output_buf, colwise_output_buf, act_input_buf, scale_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias); workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
...@@ -466,7 +483,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, ...@@ -466,7 +483,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<bool>("is_dbias")); .Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"));
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -143,6 +143,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -143,6 +143,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("QGEGLU", NVTE_Activation_Type::QGEGLU) .value("QGEGLU", NVTE_Activation_Type::QGEGLU)
.value("SRELU", NVTE_Activation_Type::SRELU) .value("SRELU", NVTE_Activation_Type::SRELU)
.value("SREGLU", NVTE_Activation_Type::SREGLU) .value("SREGLU", NVTE_Activation_Type::SREGLU)
.value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU)
.export_values(); .export_values();
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
......
...@@ -898,6 +898,10 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -898,6 +898,10 @@ class LayerNormMLP(TransformerEngineBase):
activations: Sequence[Union[str, Callable]], default = ('relu',) activations: Sequence[Union[str, Callable]], default = ('relu',)
The sequence of activation functions to apply after the first dense layer transformation. The sequence of activation functions to apply after the first dense layer transformation.
Each activation has its own transformation layer. Each activation has its own transformation layer.
activation_params: dict, default = None
The parameters needed(if any) by the activation functions specified in :attr:`activations`.
At the moment only ('clamped_silu', 'clamped_linear') which is clamped_swiglu used in GPT OSS
need additional parameters.
intermediate_dropout_rng_name: str, default = 'dropout' intermediate_dropout_rng_name: str, default = 'dropout'
The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks.
intermediate_dropout_rate: float, default = 0.1 intermediate_dropout_rate: float, default = 0.1
...@@ -956,6 +960,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -956,6 +960,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_axes_2: Tuple[str, ...] = ("embed",) bias_axes_2: Tuple[str, ...] = ("embed",)
return_layernorm_output: bool = True return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ("relu",) activations: Sequence[Union[str, Callable]] = ("relu",)
activation_params: dict = None
intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rng_name: str = "dropout"
intermediate_dropout_rate: float = 0.1 intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = () intermediate_hidden_dropout_dims: Sequence[int] = ()
...@@ -1023,6 +1028,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1023,6 +1028,7 @@ class LayerNormMLP(TransformerEngineBase):
("relu", "linear"), ("relu", "linear"),
("quick_gelu", "linear"), ("quick_gelu", "linear"),
("squared_relu", "linear"), ("squared_relu", "linear"),
("clamped_silu", "clamped_linear"),
] ]
act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)] act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)]
normalized_acts = [] normalized_acts = []
...@@ -1031,7 +1037,9 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1031,7 +1037,9 @@ class LayerNormMLP(TransformerEngineBase):
return False return False
normalized_acts.append(act.lower()) normalized_acts.append(act.lower())
normalized_acts = tuple( normalized_acts = tuple(
reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts reversed(normalized_acts)
if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear")
else normalized_acts
) )
is_act_implemented = normalized_acts in (gated_act_pool + act_pool) is_act_implemented = normalized_acts in (gated_act_pool + act_pool)
...@@ -1150,6 +1158,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1150,6 +1158,7 @@ class LayerNormMLP(TransformerEngineBase):
ffn1_ckpt_name=self.ffn1_ckpt_name, ffn1_ckpt_name=self.ffn1_ckpt_name,
ffn2_ckpt_name=self.ffn2_ckpt_name, ffn2_ckpt_name=self.ffn2_ckpt_name,
activation_type=normalized_acts, activation_type=normalized_acts,
activation_params=self.activation_params,
quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
) )
out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)
...@@ -1287,4 +1296,4 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1287,4 +1296,4 @@ class LayerNormMLP(TransformerEngineBase):
out = checkpoint_name(out, self.ffn2_ckpt_name) out = checkpoint_name(out, self.ffn2_ckpt_name)
assert out.dtype == input_dtype assert out.dtype == input_dtype
return out, ln_output # Output, layner_norm_output return out, ln_output # Output, layer_norm_output
...@@ -1632,6 +1632,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1632,6 +1632,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mlp_activations: Sequence[str], default = ('relu', ) mlp_activations: Sequence[str], default = ('relu', )
The sequence of activation functions to apply after the first linear transformation. The sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer. Each activation has its own transformation layer.
mlp_activation_params: dict = None
This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment
ClampedSwiglu is the only activation that requires parameters.
use_bias: bool, default = False use_bias: bool, default = False
Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
If set to False, the layer will not learn additive biases. If set to False, the layer will not learn additive biases.
...@@ -1752,6 +1755,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1752,6 +1755,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mha_kernel_init: Initializer = None mha_kernel_init: Initializer = None
mlp_kernel_init: Initializer = None mlp_kernel_init: Initializer = None
mlp_activations: Sequence[str] = ("relu",) mlp_activations: Sequence[str] = ("relu",)
mlp_activation_params: dict = None
use_bias: bool = False use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros bias_init: Initializer = nn.initializers.zeros
apply_residual_connection_post_layernorm: bool = False apply_residual_connection_post_layernorm: bool = False
...@@ -2046,6 +2050,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -2046,6 +2050,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
return_layernorm_output=self.apply_residual_connection_post_layernorm, return_layernorm_output=self.apply_residual_connection_post_layernorm,
intermediate_dim=self.mlp_hidden_size, intermediate_dim=self.mlp_hidden_size,
activations=self.mlp_activations, activations=self.mlp_activations,
activation_params=self.mlp_activation_params,
intermediate_dropout_rng_name=self.dropout_rng_name, intermediate_dropout_rng_name=self.dropout_rng_name,
intermediate_dropout_rate=self.intermediate_dropout, intermediate_dropout_rate=self.intermediate_dropout,
intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
......
...@@ -50,6 +50,7 @@ def layernorm_mlp( ...@@ -50,6 +50,7 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1", ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2", ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
activation_params: dict = None,
collective_op_sets: Tuple[tex.CollectiveOpSet] = ( collective_op_sets: Tuple[tex.CollectiveOpSet] = (
tex.noop_collective_op_set, tex.noop_collective_op_set,
tex.noop_collective_op_set, tex.noop_collective_op_set,
...@@ -138,13 +139,14 @@ def layernorm_mlp( ...@@ -138,13 +139,14 @@ def layernorm_mlp(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
activation_params,
collective_op_sets, collective_op_sets,
quantizer_sets, quantizer_sets,
) )
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19)) @partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20))
def _layernorm_mlp( def _layernorm_mlp(
x: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
...@@ -165,6 +167,7 @@ def _layernorm_mlp( ...@@ -165,6 +167,7 @@ def _layernorm_mlp(
ffn1_ckpt_name: str, ffn1_ckpt_name: str,
ffn2_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
activation_params: dict,
collective_op_sets: Tuple[tex.CollectiveOpSet], collective_op_sets: Tuple[tex.CollectiveOpSet],
quantizer_sets, quantizer_sets,
): ):
...@@ -220,6 +223,7 @@ def _layernorm_mlp( ...@@ -220,6 +223,7 @@ def _layernorm_mlp(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
activation_params,
collective_op_sets, collective_op_sets,
quantizer_sets, quantizer_sets,
) )
...@@ -246,6 +250,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -246,6 +250,7 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
activation_params,
collective_op_sets, collective_op_sets,
quantizer_sets, quantizer_sets,
): ):
...@@ -335,6 +340,11 @@ def _layernorm_mlp_fwd_rule( ...@@ -335,6 +340,11 @@ def _layernorm_mlp_fwd_rule(
dot_1_output, dot_1_output,
activation_type, activation_type,
quantizer=ffn2_quantizer_set.x, quantizer=ffn2_quantizer_set.x,
act_params=(
tex.activation.ActivationParams.create(activation_type, **activation_params)
if activation_params
else None
),
) )
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
...@@ -402,6 +412,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -402,6 +412,7 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
activation_params,
collective_op_sets, collective_op_sets,
ctx, ctx,
grad, grad,
...@@ -497,6 +508,11 @@ def _layernorm_mlp_bwd_rule( ...@@ -497,6 +508,11 @@ def _layernorm_mlp_bwd_rule(
activation_type=activation_type, activation_type=activation_type,
is_dbias=use_bias_1, is_dbias=use_bias_1,
quantizer=ffn2_quantizer_set.dgrad, quantizer=ffn2_quantizer_set.dgrad,
act_params=(
tex.activation.ActivationParams.create(activation_type, **activation_params)
if activation_params
else None
),
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment