Unverified Commit 127b6d3a authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Activation/Normalization to output amax for later quantization in CurrentScaling (#2238)



* reuse amax for current scaling
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 9f3e79bf
...@@ -543,6 +543,18 @@ class AmaxScope(Enum): ...@@ -543,6 +543,18 @@ class AmaxScope(Enum):
TPSP = 2 TPSP = 2
FSDP = 3 FSDP = 3
def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh):
"""Reduce the amax based on its scope"""
gmesh = global_mesh_resource()
sequence_dim = 0 if transpose_batch_sequence else 1
# Run AR across TPSP only when tensor-sequence is detected in the input spec
if self is AmaxScope.TPSP and data_spec[sequence_dim] == gmesh.tpsp_resource:
return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
# Run AR across FSDP
if self is AmaxScope.FSDP:
return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
return amax
class AmaxCalculationPrimitive(BasePrimitive): class AmaxCalculationPrimitive(BasePrimitive):
""" """
...@@ -554,7 +566,7 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -554,7 +566,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
impl_static_args = ( impl_static_args = (
1, 1,
2, 2,
) # amax_scope, batch_sequence_transpose ) # amax_scope, transpose_batch_sequence
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -563,12 +575,12 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -563,12 +575,12 @@ class AmaxCalculationPrimitive(BasePrimitive):
x_aval, x_aval,
*, *,
amax_scope, amax_scope,
batch_sequence_transpose, transpose_batch_sequence,
): ):
""" """
amax calcuation abstract amax calcuation abstract
""" """
del amax_scope, batch_sequence_transpose del amax_scope, transpose_batch_sequence
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -580,19 +592,19 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -580,19 +592,19 @@ class AmaxCalculationPrimitive(BasePrimitive):
def impl( def impl(
x, x,
amax_scope, amax_scope,
batch_sequence_transpose, transpose_batch_sequence,
): ):
""" """
amax calcuation implementation amax calcuation implementation
""" """
del amax_scope, batch_sequence_transpose del amax_scope, transpose_batch_sequence
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,))
return amax return amax
@staticmethod @staticmethod
def infer_sharding_from_operands( def infer_sharding_from_operands(
amax_scope, amax_scope,
batch_sequence_transpose, transpose_batch_sequence,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -600,7 +612,7 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -600,7 +612,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
""" """
amax calcuation infer_sharding_from_operands amax calcuation infer_sharding_from_operands
""" """
del (amax_scope, batch_sequence_transpose, arg_infos, result_infos) # Unused. del (amax_scope, transpose_batch_sequence, arg_infos, result_infos) # Unused.
amax_sharding = NamedSharding( amax_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(None), PartitionSpec(None),
...@@ -611,7 +623,7 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -611,7 +623,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
@staticmethod @staticmethod
def partition( def partition(
amax_scope, amax_scope,
batch_sequence_transpose, transpose_batch_sequence,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -631,16 +643,11 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -631,16 +643,11 @@ class AmaxCalculationPrimitive(BasePrimitive):
amax = AmaxCalculationPrimitive.impl( amax = AmaxCalculationPrimitive.impl(
x, x,
amax_scope=amax_scope, amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
)
amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
amax, x_spec, transpose_batch_sequence, mesh
) )
gmesh = global_mesh_resource()
sequence_dim = 0 if batch_sequence_transpose else 1
# Run AR across TPSP only when tensor-sequence is detected in the input spec
if amax_scope is AmaxScope.TPSP and x_spec[sequence_dim] == gmesh.tpsp_resource:
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
# Run AR across FSDP
if amax_scope is AmaxScope.FSDP:
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
return amax return amax
...@@ -648,11 +655,11 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -648,11 +655,11 @@ class AmaxCalculationPrimitive(BasePrimitive):
return mesh, sharded_impl, amax_sharding, arg_shardings return mesh, sharded_impl, amax_sharding, arg_shardings
@staticmethod @staticmethod
def shardy_sharding_rule(amax_scope, batch_sequence_transpose, mesh, value_types, result_types): def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types):
""" """
amax calcuation shardy_sharding_rule amax calcuation shardy_sharding_rule
""" """
del amax_scope, batch_sequence_transpose, mesh, result_types del amax_scope, transpose_batch_sequence, mesh, result_types
prefix = "AmaxCal" prefix = "AmaxCal"
input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
output_spec = (f"{prefix}_amax",) output_spec = (f"{prefix}_amax",)
...@@ -709,7 +716,7 @@ def _quantize_dbias_impl( ...@@ -709,7 +716,7 @@ def _quantize_dbias_impl(
dq_dtype: Optional[jnp.dtype] = None, dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1, flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling
batch_sequence_transpose: bool = False, transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
""" """
Cast wrapper Cast wrapper
...@@ -755,12 +762,12 @@ def _quantize_dbias_impl( ...@@ -755,12 +762,12 @@ def _quantize_dbias_impl(
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
amax_scope=amax_scope, amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias return out, dbias
scale = jnp.empty((), jnp.float32) scale = jnp.empty((1,), jnp.float32)
amax = None amax = None
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Globally reduce amax across all devices for current scaling so we have a single global scale. # Globally reduce amax across all devices for current scaling so we have a single global scale.
...@@ -771,7 +778,7 @@ def _quantize_dbias_impl( ...@@ -771,7 +778,7 @@ def _quantize_dbias_impl(
amax = AmaxCalculationPrimitive.outer_primitive.bind( amax = AmaxCalculationPrimitive.outer_primitive.bind(
x.data, x.data,
amax_scope=amax_scope, amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
scale = compute_scale_from_amax(amax, quantizer.q_dtype) scale = compute_scale_from_amax(amax, quantizer.q_dtype)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
...@@ -845,7 +852,7 @@ def quantize( ...@@ -845,7 +852,7 @@ def quantize(
quantizer: Quantizer, quantizer: Quantizer,
flatten_axis: int = -1, flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL, amax_scope: AmaxScope = AmaxScope.LOCAL,
batch_sequence_transpose: bool = False, transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor]: ) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer. """Quantize input tensor according to the quantizer.
...@@ -866,7 +873,7 @@ def quantize( ...@@ -866,7 +873,7 @@ def quantize(
quantizer=quantizer, quantizer=quantizer,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
amax_scope=amax_scope, amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
return out return out
...@@ -877,7 +884,7 @@ def quantize_dbias( ...@@ -877,7 +884,7 @@ def quantize_dbias(
is_dbias: bool = True, is_dbias: bool = True,
flatten_axis: int = -1, flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL, amax_scope: AmaxScope = AmaxScope.LOCAL,
batch_sequence_transpose: bool = False, transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient. """Quantize input tensor and compute bias gradient.
...@@ -904,7 +911,7 @@ def quantize_dbias( ...@@ -904,7 +911,7 @@ def quantize_dbias(
is_dbias=is_dbias, is_dbias=is_dbias,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
amax_scope=amax_scope, amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
......
...@@ -15,13 +15,14 @@ namespace transformer_engine { ...@@ -15,13 +15,14 @@ namespace transformer_engine {
namespace jax { namespace jax {
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int, ActivationConfig act_params) { bool is_2x_int, ActivationConfig act_params, bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS // parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
...@@ -30,7 +31,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -30,7 +31,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto *colwise_output = colwise_output_buf->untyped_data(); auto *colwise_output = colwise_output_buf->untyped_data();
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data()); float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto input_dims = input_buf.dimensions(); auto input_dims = input_buf.dimensions();
auto m = product(input_dims, 0, input_dims.size() - 2); auto m = product(input_dims, 0, input_dims.size() - 2);
...@@ -45,7 +48,12 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -45,7 +48,12 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n), m}; auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n), m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype)); auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape); output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
NVTE_CHECK( NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
...@@ -55,10 +63,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -55,10 +63,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv( output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(), scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1}); convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
...@@ -145,26 +150,29 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, ...@@ -145,26 +150,29 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ctx<FFI_Stream_Type>() // stream .Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input .Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output .Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise .Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"), .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Buffer_Type amax_buf, Result_Type output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
JAXX_Scaling_Mode scaling_mode, bool is_2x_int, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int,
ActivationConfig act_params) { ActivationConfig act_params, bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf,
act_enum, scaling_mode, is_2x_int, act_params); updated_amax_buf, act_enum, scaling_mode, is_2x_int, act_params,
output_amax_when_no_scaling);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
...@@ -172,15 +180,17 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, ...@@ -172,15 +180,17 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
.Ctx<FFI_Stream_Type>() // stream .Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input .Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output .Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise .Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params")); .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
...@@ -246,15 +256,17 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -246,15 +256,17 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type scale_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Buffer_Type amax_buf, Result_Type output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, Result_Type dbias_buf, Result_Type workspace_buf,
int64_t act_enum, bool is_2x, bool is_dbias, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x,
ActivationConfig act_params) { bool is_dbias, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS // parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
...@@ -262,7 +274,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -262,7 +274,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto *input = input_buf.untyped_data(); auto *input = input_buf.untyped_data();
auto *act_input = act_input_buf.untyped_data(); auto *act_input = act_input_buf.untyped_data();
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data()); float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data()); float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto act_type = static_cast<NVTE_Activation_Type>(act_enum); auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
...@@ -305,13 +319,14 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -305,13 +319,14 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, out_dtype, output_shape); output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv( output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(), scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1}); convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
...@@ -440,6 +455,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI ...@@ -440,6 +455,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Arg<Buffer_Type>() // input .Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // act input .Arg<Buffer_Type>() // act input
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output .Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
...@@ -451,19 +467,22 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI ...@@ -451,19 +467,22 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"), .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type DActLuDBiasQuantizeInitializeFFI( Error_Type DActLuDBiasQuantizeInitializeFFI(
cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
bool is_dbias, ActivationConfig act_params) { int64_t act_enum, bool is_2x, bool is_dbias, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
act_input_buf, scale_buf, output_buf, colwise_output_buf, act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params); workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params,
output_amax_when_no_scaling);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
...@@ -473,18 +492,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, ...@@ -473,18 +492,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
.Arg<Buffer_Type>() // input .Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // act input .Arg<Buffer_Type>() // act input
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output .Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise .Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // dbias .Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params")); .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"));
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -29,6 +29,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si ...@@ -29,6 +29,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape); output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape);
output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
// WAR: NVTE Norms query the is_training from whereas columwise_data is allocated // WAR: NVTE Norms query the is_training from whereas columwise_data is allocated
if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
...@@ -59,12 +60,13 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si ...@@ -59,12 +60,13 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
} }
Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf, Buffer_Type amax_buf, Buffer_Type gamma_buf, Buffer_Type beta_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf,
int norm_type, bool zero_centered_gamma, double epsilon, Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma,
int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) { double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode,
bool is_2x, bool output_amax_when_no_scaling) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
...@@ -77,9 +79,12 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -77,9 +79,12 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data(); auto *rsigma = rsigma_buf->untyped_data();
auto *mu = mu_buf->untyped_data(); auto *mu = mu_buf->untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto *workspace = wkspace_buf->untyped_data(); auto *workspace = wkspace_buf->untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type); auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x); auto _is_2x = static_cast<bool>(is_2x);
...@@ -106,6 +111,10 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -106,6 +111,10 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape); output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
NVTE_CHECK( NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
...@@ -123,8 +132,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -123,8 +132,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
} }
if (_is_2x) { if (_is_2x) {
...@@ -162,13 +169,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, ...@@ -162,13 +169,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Ctx<FFI_Stream_Type>() // stream .Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x .Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // gamma .Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta .Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output .Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv .Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // mu .Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma .Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
...@@ -177,21 +185,25 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, ...@@ -177,21 +185,25 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<double>("epsilon") .Attr<double>("epsilon")
.Attr<int64_t>("sm_margin") .Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"), .Attr<bool>("is_2x")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type gamma_buf, Buffer_Type beta_buf, Buffer_Type amax_buf, Buffer_Type gamma_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Buffer_Type beta_buf, Result_Type output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type amax_buf, Result_Type mu_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type,
bool zero_centered_gamma, double epsilon, int64_t sm_margin, bool zero_centered_gamma, double epsilon, int64_t sm_margin,
JAXX_Scaling_Mode scaling_mode, bool is_2x) { JAXX_Scaling_Mode scaling_mode, bool is_2x,
return wrapInStreamCapture( bool output_amax_when_no_scaling) {
std::function(NormForwardFFI), stream, x_buf, scale_buf, gamma_buf, beta_buf, output_buf, return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, mu_buf, rsigma_buf, gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, scaling_mode, is_2x); colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin,
scaling_mode, is_2x, output_amax_when_no_scaling);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI, XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI,
...@@ -199,13 +211,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ ...@@ -199,13 +211,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ
.Ctx<FFI_Stream_Type>() // stream .Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x .Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // gamma .Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta .Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output .Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv .Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // mu .Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma .Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
...@@ -214,7 +227,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ ...@@ -214,7 +227,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ
.Attr<double>("epsilon") .Attr<double>("epsilon")
.Attr<int64_t>("sm_margin") .Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")); .Attr<bool>("is_2x")
.Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, NVTE_Norm_Type norm_type, DType w_dtype, NVTE_Norm_Type norm_type,
......
...@@ -120,9 +120,11 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -120,9 +120,11 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
if (is_tensor_scaling) { if (is_tensor_scaling) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data()); float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data()); float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); NVTE_CHECK(amax == updated_amax && amax != nullptr,
"amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv( output_tensor.set_rowwise_scale_inv(
......
...@@ -63,7 +63,7 @@ def dense( ...@@ -63,7 +63,7 @@ def dense(
kernel: jnp.ndarray, kernel: jnp.ndarray,
bias: jnp.ndarray = None, bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
batch_sequence_transpose: bool = False, transpose_batch_sequence: bool = False,
input_axes: Tuple[str, ...] = None, input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
output_axes: Tuple[str, ...] = None, output_axes: Tuple[str, ...] = None,
...@@ -81,7 +81,7 @@ def dense( ...@@ -81,7 +81,7 @@ def dense(
kernel: Weight matrix for the dense layer transformation kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
output_axes: Logical axes for sharding the output output_axes: Logical axes for sharding the output
...@@ -91,8 +91,8 @@ def dense( ...@@ -91,8 +91,8 @@ def dense(
Returns: Returns:
Transformed output tensor Transformed output tensor
""" """
if batch_sequence_transpose: if transpose_batch_sequence:
warnings.warn("batch_sequence_transpose is not well tested, use with caution!") warnings.warn("transpose_batch_sequence is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled(): if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype input_dtype = x.dtype
...@@ -103,7 +103,7 @@ def dense( ...@@ -103,7 +103,7 @@ def dense(
kernel, kernel,
bias, bias,
contracting_dims, contracting_dims,
batch_sequence_transpose, transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
output_axes, output_axes,
...@@ -119,7 +119,7 @@ def _dense( ...@@ -119,7 +119,7 @@ def _dense(
kernel, kernel,
bias, bias,
contracting_dims, contracting_dims,
batch_sequence_transpose, transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
output_axes, output_axes,
...@@ -136,7 +136,7 @@ def _dense( ...@@ -136,7 +136,7 @@ def _dense(
kernel: Weight matrix kernel: Weight matrix
bias: Optional bias tensor bias: Optional bias tensor
contracting_dims: Contracting dimensions specification contracting_dims: Contracting dimensions specification
batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input input_axes: Logical axes for sharding the activation input
output_axes: Logical axes for sharding the output_axes output_axes: Logical axes for sharding the output_axes
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
...@@ -151,7 +151,7 @@ def _dense( ...@@ -151,7 +151,7 @@ def _dense(
kernel, kernel,
bias, bias,
contracting_dims, contracting_dims,
batch_sequence_transpose, transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
output_axes, output_axes,
...@@ -166,7 +166,7 @@ def _dense_fwd_rule( ...@@ -166,7 +166,7 @@ def _dense_fwd_rule(
kernel, kernel,
bias, bias,
contracting_dims, contracting_dims,
batch_sequence_transpose, transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
output_axes, output_axes,
...@@ -197,7 +197,7 @@ def _dense_fwd_rule( ...@@ -197,7 +197,7 @@ def _dense_fwd_rule(
flatten_axis=flatten_axis_x, flatten_axis=flatten_axis_x,
quantizer=quantizer_set.x, quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP, amax_scope=AmaxScope.TPSP,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
...@@ -215,7 +215,7 @@ def _dense_fwd_rule( ...@@ -215,7 +215,7 @@ def _dense_fwd_rule(
casted_x.get_tensor(usage=TensorUsage.LHS), casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS), casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
bias=bias if not tex.gemm_uses_jax_dot() else None, bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set.forward, collective_op=collective_op_set.forward,
...@@ -240,7 +240,7 @@ def _dense_fwd_rule( ...@@ -240,7 +240,7 @@ def _dense_fwd_rule(
def _dense_bwd_rule( def _dense_bwd_rule(
contracting_dims, contracting_dims,
batch_sequence_transpose, transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
output_axes, output_axes,
...@@ -274,7 +274,7 @@ def _dense_bwd_rule( ...@@ -274,7 +274,7 @@ def _dense_bwd_rule(
flatten_axis=flatten_axis_k, flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad, quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP, amax_scope=AmaxScope.TPSP,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
# GEMM NT # GEMM NT
...@@ -291,7 +291,7 @@ def _dense_bwd_rule( ...@@ -291,7 +291,7 @@ def _dense_bwd_rule(
casted_grad.get_tensor(usage=TensorUsage.LHS), casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs, casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim), contracting_dims=(g_contracting_dim, k_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set.backward, collective_op=collective_op_set.backward,
) )
...@@ -305,7 +305,7 @@ def _dense_bwd_rule( ...@@ -305,7 +305,7 @@ def _dense_bwd_rule(
casted_x_lhs, casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS), casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim), contracting_dims=(x_contracting_dim, g_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
......
...@@ -432,6 +432,8 @@ class DenseGeneral(TransformerEngineBase): ...@@ -432,6 +432,8 @@ class DenseGeneral(TransformerEngineBase):
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
""" """
features: Union[Iterable[int], int] features: Union[Iterable[int], int]
...@@ -446,6 +448,7 @@ class DenseGeneral(TransformerEngineBase): ...@@ -446,6 +448,7 @@ class DenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
input_axes: Tuple[str, ...] = () input_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -512,6 +515,7 @@ class DenseGeneral(TransformerEngineBase): ...@@ -512,6 +515,7 @@ class DenseGeneral(TransformerEngineBase):
input_axes=self.input_axes, input_axes=self.input_axes,
kernel_axes=self.kernel_axes, kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
...@@ -632,6 +636,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -632,6 +636,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
depth_scaling: float, default = None depth_scaling: float, default = None
The factor to scale the output from `DenseGeneral`. It should be a float The factor to scale the output from `DenseGeneral`. It should be a float
value or None. When None is set, then no scaling is applied. value or None. When None is set, then no scaling is applied.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
""" """
features: Union[Iterable[int], int] features: Union[Iterable[int], int]
...@@ -657,6 +663,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -657,6 +663,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
layernorm_input_axes: Tuple[str, ...] = None layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None depth_scaling: float = None
transpose_batch_sequence: bool = False
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -768,6 +775,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -768,6 +775,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
dot_input_axes=self.dot_input_axes, dot_input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes, kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
...@@ -775,6 +783,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -775,6 +783,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
y, y,
kernel, kernel,
contracting_dims=(axis, contract_ind), contracting_dims=(axis, contract_ind),
transpose_batch_sequence=self.transpose_batch_sequence,
input_axes=self.dot_input_axes, input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes, kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
...@@ -940,6 +949,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -940,6 +949,8 @@ class LayerNormMLP(TransformerEngineBase):
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
""" """
intermediate_dim: int = 2048 intermediate_dim: int = 2048
...@@ -974,6 +985,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -974,6 +985,7 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes: Tuple[str, ...] = None dot_2_input_axes: Tuple[str, ...] = None
ffn1_ckpt_name: str = "ffn1" ffn1_ckpt_name: str = "ffn1"
ffn2_ckpt_name: str = "ffn2" ffn2_ckpt_name: str = "ffn2"
transpose_batch_sequence: bool = False
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -1160,6 +1172,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1160,6 +1172,7 @@ class LayerNormMLP(TransformerEngineBase):
activation_type=normalized_acts, activation_type=normalized_acts,
activation_params=self.activation_params, activation_params=self.activation_params,
quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
transpose_batch_sequence=self.transpose_batch_sequence,
) )
out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)
...@@ -1178,6 +1191,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1178,6 +1191,7 @@ class LayerNormMLP(TransformerEngineBase):
dot_input_axes=self.dot_1_input_axes, dot_input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1, kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes) y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
...@@ -1188,6 +1202,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1188,6 +1202,7 @@ class LayerNormMLP(TransformerEngineBase):
input_axes=self.dot_1_input_axes, input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1, kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
...@@ -1260,6 +1275,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1260,6 +1275,7 @@ class LayerNormMLP(TransformerEngineBase):
input_axes=self.dot_2_input_axes, input_axes=self.dot_2_input_axes,
kernel_axes=self.kernel_axes_2, kernel_axes=self.kernel_axes_2,
quantizer_set=ffn2_quantizer_set, quantizer_set=ffn2_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
......
...@@ -1207,6 +1207,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1207,6 +1207,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
layernorm_input_axes=inputs_logical_axes_maybe_sp, layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp, dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="qkv", name="qkv",
dtype=self.dtype, dtype=self.dtype,
)(inputs_q) )(inputs_q)
...@@ -1234,6 +1235,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1234,6 +1235,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=query_init, kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp, layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp, dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="query", name="query",
)(inputs_q) )(inputs_q)
...@@ -1252,6 +1254,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1252,6 +1254,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
enable_low_rank_adaptation=lora_scope.qkv_proj, enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
transpose_batch_sequence=self.transpose_batch_sequence,
name="kv", name="kv",
dtype=self.dtype, dtype=self.dtype,
)(inputs_kv) )(inputs_kv)
...@@ -1292,6 +1295,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1292,6 +1295,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=query_init, kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp, layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp, dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="query", name="query",
)(inputs_q) )(inputs_q)
...@@ -2070,6 +2074,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -2070,6 +2074,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES), layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES), dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES), dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
transpose_batch_sequence=self.transpose_batch_sequence,
name="mlp", name="mlp",
)(mlp_input, deterministic=deterministic) )(mlp_input, deterministic=deterministic)
......
...@@ -16,6 +16,7 @@ import jax ...@@ -16,6 +16,7 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .quantize import ( from .quantize import (
QuantizerSet, QuantizerSet,
...@@ -35,6 +36,7 @@ def layernorm_dense( ...@@ -35,6 +36,7 @@ def layernorm_dense(
norm_type: str = "layernorm", norm_type: str = "layernorm",
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
epsilon: float = 1e-6, epsilon: float = 1e-6,
transpose_batch_sequence: bool = False,
layernorm_input_axes: Tuple[str, ...] = None, layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
...@@ -55,6 +57,7 @@ def layernorm_dense( ...@@ -55,6 +57,7 @@ def layernorm_dense(
norm_type: Type of normalization ("layernorm" or "rmsnorm") norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization epsilon: Small constant for numerical stability in normalization
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
layernorm_input_axes: Logical axes for sharding the layernorm input layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
...@@ -83,6 +86,7 @@ def layernorm_dense( ...@@ -83,6 +86,7 @@ def layernorm_dense(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
...@@ -100,6 +104,7 @@ def layernorm_dense( ...@@ -100,6 +104,7 @@ def layernorm_dense(
8, 8,
9, 9,
10, 10,
11,
), ),
) )
def _layernorm_dense( def _layernorm_dense(
...@@ -111,6 +116,7 @@ def _layernorm_dense( ...@@ -111,6 +116,7 @@ def _layernorm_dense(
norm_type: str, norm_type: str,
zero_centered_gamma: bool, zero_centered_gamma: bool,
epsilon: float, epsilon: float,
transpose_batch_sequence: bool,
layernorm_input_axes: Tuple[str, ...], layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...],
...@@ -131,6 +137,7 @@ def _layernorm_dense( ...@@ -131,6 +137,7 @@ def _layernorm_dense(
norm_type: Type of normalization norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability epsilon: Small constant for numerical stability
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
layernorm_input_axes: Logical axes for layernorm sharding layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding dot_input_axes: Logical axes for matrix multiplication sharding
quantizer_set: Set of quantizers quantizer_set: Set of quantizers
...@@ -147,6 +154,7 @@ def _layernorm_dense( ...@@ -147,6 +154,7 @@ def _layernorm_dense(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
...@@ -164,6 +172,7 @@ def _layernorm_dense_fwd_rule( ...@@ -164,6 +172,7 @@ def _layernorm_dense_fwd_rule(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
...@@ -194,6 +203,8 @@ def _layernorm_dense_fwd_rule( ...@@ -194,6 +203,8 @@ def _layernorm_dense_fwd_rule(
epsilon, epsilon,
norm_type, norm_type,
quantizer=quantizer_set.x, quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
...@@ -203,6 +214,8 @@ def _layernorm_dense_fwd_rule( ...@@ -203,6 +214,8 @@ def _layernorm_dense_fwd_rule(
kernel, kernel,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
quantizer=quantizer_set.kernel, quantizer=quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
) )
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
...@@ -213,6 +226,7 @@ def _layernorm_dense_fwd_rule( ...@@ -213,6 +226,7 @@ def _layernorm_dense_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS), casted_kernel.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=transpose_batch_sequence,
bias=bias if not tex.gemm_uses_jax_dot() else None, bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
) )
...@@ -245,6 +259,7 @@ def _layernorm_dense_bwd_rule( ...@@ -245,6 +259,7 @@ def _layernorm_dense_bwd_rule(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
...@@ -285,6 +300,8 @@ def _layernorm_dense_bwd_rule( ...@@ -285,6 +300,8 @@ def _layernorm_dense_bwd_rule(
is_dbias=use_bias, is_dbias=use_bias,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
quantizer=quantizer_set.dgrad, quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
...@@ -301,6 +318,7 @@ def _layernorm_dense_bwd_rule( ...@@ -301,6 +318,7 @@ def _layernorm_dense_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS), casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel, casted_kernel,
contracting_dims=(g_constracting_dim, k_constracting_dim), contracting_dims=(g_constracting_dim, k_constracting_dim),
transpose_batch_sequence=transpose_batch_sequence,
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
...@@ -314,6 +332,7 @@ def _layernorm_dense_bwd_rule( ...@@ -314,6 +332,7 @@ def _layernorm_dense_bwd_rule(
casted_ln_out, casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS), casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_constracting_dim, g_constracting_dim), contracting_dims=(x_constracting_dim, g_constracting_dim),
transpose_batch_sequence=transpose_batch_sequence,
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
...@@ -41,7 +41,7 @@ def layernorm_mlp( ...@@ -41,7 +41,7 @@ def layernorm_mlp(
norm_type: str, norm_type: str,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
epsilon: float = 1e-6, epsilon: float = 1e-6,
batch_sequence_transpose: bool = False, transpose_batch_sequence: bool = False,
norm_input_axes: Tuple[str, ...] = None, norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None,
...@@ -78,7 +78,7 @@ def layernorm_mlp( ...@@ -78,7 +78,7 @@ def layernorm_mlp(
norm_type: Type of normalization ("layernorm" or "rmsnorm") norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization epsilon: Small constant for numerical stability in normalization
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for sharding the layernorm input norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication
...@@ -130,7 +130,7 @@ def layernorm_mlp( ...@@ -130,7 +130,7 @@ def layernorm_mlp(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
batch_sequence_transpose, transpose_batch_sequence,
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
...@@ -158,7 +158,7 @@ def _layernorm_mlp( ...@@ -158,7 +158,7 @@ def _layernorm_mlp(
norm_type: str, norm_type: str,
zero_centered_gamma: bool, zero_centered_gamma: bool,
epsilon: float, epsilon: float,
batch_sequence_transpose: bool, transpose_batch_sequence: bool,
norm_input_axes: Tuple[str, ...], norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
...@@ -188,7 +188,7 @@ def _layernorm_mlp( ...@@ -188,7 +188,7 @@ def _layernorm_mlp(
norm_type: Type of normalization norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability epsilon: Small constant for numerical stability
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for layernorm sharding norm_input_axes: Logical axes for layernorm sharding
dot_1_input_axes: Logical axes for first matrix multiplication sharding dot_1_input_axes: Logical axes for first matrix multiplication sharding
dot_2_input_axes: Logical axes for second matrix multiplication sharding dot_2_input_axes: Logical axes for second matrix multiplication sharding
...@@ -214,7 +214,7 @@ def _layernorm_mlp( ...@@ -214,7 +214,7 @@ def _layernorm_mlp(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
batch_sequence_transpose, transpose_batch_sequence,
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
...@@ -241,7 +241,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -241,7 +241,7 @@ def _layernorm_mlp_fwd_rule(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
batch_sequence_transpose, transpose_batch_sequence,
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
...@@ -302,11 +302,16 @@ def _layernorm_mlp_fwd_rule( ...@@ -302,11 +302,16 @@ def _layernorm_mlp_fwd_rule(
norm_type, norm_type,
quantizer=ffn1_quantizer_set.x, quantizer=ffn1_quantizer_set.x,
amax_scope=AmaxScope.TPSP, amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize( casted_kernel_1 = tex.quantize(
kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP kernel_1,
flatten_axis=-2,
quantizer=ffn1_quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# NN GEMM # NN GEMM
...@@ -315,7 +320,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -315,7 +320,7 @@ def _layernorm_mlp_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS), casted_kernel_1.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
bias=bias_1 if not tex.gemm_uses_jax_dot() else None, bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_1.forward, collective_op=collective_op_set_1.forward,
...@@ -345,6 +350,8 @@ def _layernorm_mlp_fwd_rule( ...@@ -345,6 +350,8 @@ def _layernorm_mlp_fwd_rule(
if activation_params if activation_params
else None else None
), ),
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
...@@ -353,6 +360,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -353,6 +360,7 @@ def _layernorm_mlp_fwd_rule(
kernel_2, kernel_2,
quantizer=ffn2_quantizer_set.kernel, quantizer=ffn2_quantizer_set.kernel,
amax_scope=AmaxScope.FSDP, amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# NN GEMM # NN GEMM
...@@ -361,7 +369,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -361,7 +369,7 @@ def _layernorm_mlp_fwd_rule(
casted_act_out.get_tensor(TensorUsage.LHS), casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS), casted_kernel_2.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
bias=bias_2 if not tex.gemm_uses_jax_dot() else None, bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_2.forward, collective_op=collective_op_set_2.forward,
...@@ -403,7 +411,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -403,7 +411,7 @@ def _layernorm_mlp_bwd_rule(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
batch_sequence_transpose, transpose_batch_sequence,
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
...@@ -465,6 +473,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -465,6 +473,7 @@ def _layernorm_mlp_bwd_rule(
is_dbias=use_bias_2, is_dbias=use_bias_2,
quantizer=ffn1_quantizer_set.dgrad, quantizer=ffn1_quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP, amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
...@@ -482,7 +491,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -482,7 +491,7 @@ def _layernorm_mlp_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS), casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2, casted_kernel_2,
contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set_2.backward, collective_op=collective_op_set_2.backward,
) )
...@@ -498,7 +507,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -498,7 +507,7 @@ def _layernorm_mlp_bwd_rule(
casted_act_out, casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS), casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims), contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
...@@ -513,6 +522,8 @@ def _layernorm_mlp_bwd_rule( ...@@ -513,6 +522,8 @@ def _layernorm_mlp_bwd_rule(
if activation_params if activation_params
else None else None
), ),
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
...@@ -530,7 +541,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -530,7 +541,7 @@ def _layernorm_mlp_bwd_rule(
casted_dact_out.get_tensor(TensorUsage.LHS), casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1, casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set_1.backward, collective_op=collective_op_set_1.backward,
) )
...@@ -542,7 +553,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -542,7 +553,7 @@ def _layernorm_mlp_bwd_rule(
casted_ln_out, casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS), casted_dact_out.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims), contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment