Unverified commit 127b6d3a, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Activation/Normalization to output amax for later quantization in CurrentScaling (#2238)



* reuse amax for current scaling
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 9f3e79bf
......@@ -543,6 +543,18 @@ class AmaxScope(Enum):
TPSP = 2
FSDP = 3
def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh):
    """Reduce the amax based on its scope.

    Args:
        amax: amax value (jnp array) to be all-reduced with a max.
        data_spec: partition spec of the input tensor; inspected to detect
            tensor-sequence sharding on the sequence dimension.
        transpose_batch_sequence: if True the input layout is
            (sequence, batch, ...); otherwise (batch, sequence, ...).
        mesh: device mesh the collective runs over.

    Returns:
        The amax max-reduced across the TPSP or FSDP mesh resource when this
        scope calls for it; otherwise the input amax unchanged.
    """
    gmesh = global_mesh_resource()
    # Sequence dim is axis 0 when batch/sequence are transposed, else axis 1.
    sequence_dim = 0 if transpose_batch_sequence else 1
    # Run AR across TPSP only when tensor-sequence is detected in the input spec
    if self is AmaxScope.TPSP and data_spec[sequence_dim] == gmesh.tpsp_resource:
        return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
    # Run AR across FSDP
    if self is AmaxScope.FSDP:
        return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
    # LOCAL scope (or no matching sharding): no cross-device reduction.
    return amax
class AmaxCalculationPrimitive(BasePrimitive):
"""
......@@ -554,7 +566,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
impl_static_args = (
1,
2,
) # amax_scope, batch_sequence_transpose
) # amax_scope, transpose_batch_sequence
inner_primitive = None
outer_primitive = None
......@@ -563,12 +575,12 @@ class AmaxCalculationPrimitive(BasePrimitive):
x_aval,
*,
amax_scope,
batch_sequence_transpose,
transpose_batch_sequence,
):
"""
amax calculation abstract
"""
del amax_scope, batch_sequence_transpose
del amax_scope, transpose_batch_sequence
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -580,19 +592,19 @@ class AmaxCalculationPrimitive(BasePrimitive):
def impl(
x,
amax_scope,
batch_sequence_transpose,
transpose_batch_sequence,
):
"""
amax calculation implementation
"""
del amax_scope, batch_sequence_transpose
del amax_scope, transpose_batch_sequence
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,))
return amax
@staticmethod
def infer_sharding_from_operands(
amax_scope,
batch_sequence_transpose,
transpose_batch_sequence,
mesh,
arg_infos,
result_infos,
......@@ -600,7 +612,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
"""
amax calculation infer_sharding_from_operands
"""
del (amax_scope, batch_sequence_transpose, arg_infos, result_infos) # Unused.
del (amax_scope, transpose_batch_sequence, arg_infos, result_infos) # Unused.
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
......@@ -611,7 +623,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
@staticmethod
def partition(
amax_scope,
batch_sequence_transpose,
transpose_batch_sequence,
mesh,
arg_infos,
result_infos,
......@@ -631,16 +643,11 @@ class AmaxCalculationPrimitive(BasePrimitive):
amax = AmaxCalculationPrimitive.impl(
x,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
amax, x_spec, transpose_batch_sequence, mesh
)
gmesh = global_mesh_resource()
sequence_dim = 0 if batch_sequence_transpose else 1
# Run AR across TPSP only when tensor-sequence is detected in the input spec
if amax_scope is AmaxScope.TPSP and x_spec[sequence_dim] == gmesh.tpsp_resource:
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
# Run AR across FSDP
if amax_scope is AmaxScope.FSDP:
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
return amax
......@@ -648,11 +655,11 @@ class AmaxCalculationPrimitive(BasePrimitive):
return mesh, sharded_impl, amax_sharding, arg_shardings
@staticmethod
def shardy_sharding_rule(amax_scope, batch_sequence_transpose, mesh, value_types, result_types):
def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types):
"""
amax calculation shardy_sharding_rule
"""
del amax_scope, batch_sequence_transpose, mesh, result_types
del amax_scope, transpose_batch_sequence, mesh, result_types
prefix = "AmaxCal"
input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
output_spec = (f"{prefix}_amax",)
......@@ -709,7 +716,7 @@ def _quantize_dbias_impl(
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling
batch_sequence_transpose: bool = False,
transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""
Cast wrapper
......@@ -755,12 +762,12 @@ def _quantize_dbias_impl(
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias
scale = jnp.empty((), jnp.float32)
scale = jnp.empty((1,), jnp.float32)
amax = None
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Globally reduce amax across all devices for current scaling so we have a single global scale.
......@@ -771,7 +778,7 @@ def _quantize_dbias_impl(
amax = AmaxCalculationPrimitive.outer_primitive.bind(
x.data,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
......@@ -845,7 +852,7 @@ def quantize(
quantizer: Quantizer,
flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL,
batch_sequence_transpose: bool = False,
transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer.
......@@ -866,7 +873,7 @@ def quantize(
quantizer=quantizer,
flatten_axis=flatten_axis,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
return out
......@@ -877,7 +884,7 @@ def quantize_dbias(
is_dbias: bool = True,
flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL,
batch_sequence_transpose: bool = False,
transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient.
......@@ -904,7 +911,7 @@ def quantize_dbias(
is_dbias=is_dbias,
flatten_axis=flatten_axis,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
......
......@@ -15,13 +15,14 @@ namespace transformer_engine {
namespace jax {
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int, ActivationConfig act_params) {
Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int, ActivationConfig act_params, bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
......@@ -30,7 +31,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto *output = output_buf->untyped_data();
auto *colwise_output = colwise_output_buf->untyped_data();
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto input_dims = input_buf.dimensions();
auto m = product(input_dims, 0, input_dims.size() - 2);
......@@ -45,7 +48,12 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n), m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
......@@ -55,10 +63,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
......@@ -145,26 +150,29 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"),
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits);
Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum,
JAXX_Scaling_Mode scaling_mode, bool is_2x_int,
ActivationConfig act_params) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf,
act_enum, scaling_mode, is_2x_int, act_params);
Buffer_Type amax_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int,
ActivationConfig act_params, bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf,
output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf,
updated_amax_buf, act_enum, scaling_mode, is_2x_int, act_params,
output_amax_when_no_scaling);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
......@@ -172,15 +180,17 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"));
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
......@@ -246,15 +256,17 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias,
ActivationConfig act_params) {
Buffer_Type amax_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x,
bool is_dbias, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
......@@ -262,7 +274,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto *input = input_buf.untyped_data();
auto *act_input = act_input_buf.untyped_data();
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
......@@ -305,13 +319,14 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
......@@ -440,6 +455,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // act input
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
......@@ -451,19 +467,22 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"),
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits);
Error_Type DActLuDBiasQuantizeInitializeFFI(
cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x,
bool is_dbias, ActivationConfig act_params) {
Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
act_input_buf, scale_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params);
act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params,
output_amax_when_no_scaling);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
......@@ -473,18 +492,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // act input
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"));
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"));
} // namespace jax
} // namespace transformer_engine
......@@ -29,6 +29,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape);
output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
// WAR: NVTE Norms query is_training from whether columnwise_data is allocated
if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
......@@ -59,12 +60,13 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
}
Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf,
Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf,
int norm_type, bool zero_centered_gamma, double epsilon,
int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) {
Buffer_Type amax_buf, Buffer_Type gamma_buf, Buffer_Type beta_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma,
double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode,
bool is_2x, bool output_amax_when_no_scaling) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
......@@ -77,9 +79,12 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto *output = output_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data();
auto *mu = mu_buf->untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto *workspace = wkspace_buf->untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x);
......@@ -106,6 +111,10 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
......@@ -123,8 +132,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
if (_is_2x) {
......@@ -162,13 +169,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
......@@ -177,21 +185,25 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"),
.Attr<bool>("is_2x")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits);
Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type gamma_buf, Buffer_Type beta_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type mu_buf,
Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type,
Buffer_Type amax_buf, Buffer_Type gamma_buf,
Buffer_Type beta_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type,
bool zero_centered_gamma, double epsilon, int64_t sm_margin,
JAXX_Scaling_Mode scaling_mode, bool is_2x) {
return wrapInStreamCapture(
std::function(NormForwardFFI), stream, x_buf, scale_buf, gamma_buf, beta_buf, output_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, mu_buf, rsigma_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, scaling_mode, is_2x);
JAXX_Scaling_Mode scaling_mode, bool is_2x,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf,
gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf,
colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin,
scaling_mode, is_2x, output_amax_when_no_scaling);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI,
......@@ -199,13 +211,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
......@@ -214,7 +227,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"));
.Attr<bool>("is_2x")
.Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, NVTE_Norm_Type norm_type,
......
......@@ -120,9 +120,11 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
if (is_fp8_dtype(out_dtype)) {
if (is_tensor_scaling) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
NVTE_CHECK(amax == updated_amax && amax != nullptr,
"amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
......
......@@ -63,7 +63,7 @@ def dense(
kernel: jnp.ndarray,
bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
batch_sequence_transpose: bool = False,
transpose_batch_sequence: bool = False,
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
output_axes: Tuple[str, ...] = None,
......@@ -81,7 +81,7 @@ def dense(
kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
output_axes: Logical axes for sharding the output
......@@ -91,8 +91,8 @@ def dense(
Returns:
Transformed output tensor
"""
if batch_sequence_transpose:
warnings.warn("batch_sequence_transpose is not well tested, use with caution!")
if transpose_batch_sequence:
warnings.warn("transpose_batch_sequence is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype
......@@ -103,7 +103,7 @@ def dense(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
transpose_batch_sequence,
input_axes,
kernel_axes,
output_axes,
......@@ -119,7 +119,7 @@ def _dense(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
transpose_batch_sequence,
input_axes,
kernel_axes,
output_axes,
......@@ -136,7 +136,7 @@ def _dense(
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input
output_axes: Logical axes for sharding the output_axes
kernel_axes: Logical axes for sharding the weight matrix
......@@ -151,7 +151,7 @@ def _dense(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
transpose_batch_sequence,
input_axes,
kernel_axes,
output_axes,
......@@ -166,7 +166,7 @@ def _dense_fwd_rule(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
transpose_batch_sequence,
input_axes,
kernel_axes,
output_axes,
......@@ -197,7 +197,7 @@ def _dense_fwd_rule(
flatten_axis=flatten_axis_x,
quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
......@@ -215,7 +215,7 @@ def _dense_fwd_rule(
casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set.forward,
......@@ -240,7 +240,7 @@ def _dense_fwd_rule(
def _dense_bwd_rule(
contracting_dims,
batch_sequence_transpose,
transpose_batch_sequence,
input_axes,
kernel_axes,
output_axes,
......@@ -274,7 +274,7 @@ def _dense_bwd_rule(
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
# GEMM NT
......@@ -291,7 +291,7 @@ def _dense_bwd_rule(
casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set.backward,
)
......@@ -305,7 +305,7 @@ def _dense_bwd_rule(
casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
......
......@@ -432,6 +432,8 @@ class DenseGeneral(TransformerEngineBase):
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
"""
features: Union[Iterable[int], int]
......@@ -446,6 +448,7 @@ class DenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
input_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False
def __post_init__(self):
if self.kernel_init is None:
......@@ -512,6 +515,7 @@ class DenseGeneral(TransformerEngineBase):
input_axes=self.input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
)
if self.enable_low_rank_adaptation:
......@@ -632,6 +636,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
depth_scaling: float, default = None
The factor to scale the output from `DenseGeneral`. It should be a float
value or None. When None is set, then no scaling is applied.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
"""
features: Union[Iterable[int], int]
......@@ -657,6 +663,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None
transpose_batch_sequence: bool = False
def __post_init__(self):
if self.kernel_init is None:
......@@ -768,6 +775,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
dot_input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
)
else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
......@@ -775,6 +783,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
y,
kernel,
contracting_dims=(axis, contract_ind),
transpose_batch_sequence=self.transpose_batch_sequence,
input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
......@@ -940,6 +949,8 @@ class LayerNormMLP(TransformerEngineBase):
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
"""
intermediate_dim: int = 2048
......@@ -974,6 +985,7 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes: Tuple[str, ...] = None
ffn1_ckpt_name: str = "ffn1"
ffn2_ckpt_name: str = "ffn2"
transpose_batch_sequence: bool = False
def __post_init__(self):
if self.kernel_init is None:
......@@ -1160,6 +1172,7 @@ class LayerNormMLP(TransformerEngineBase):
activation_type=normalized_acts,
activation_params=self.activation_params,
quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
transpose_batch_sequence=self.transpose_batch_sequence,
)
out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)
......@@ -1178,6 +1191,7 @@ class LayerNormMLP(TransformerEngineBase):
dot_input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
)
else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
......@@ -1188,6 +1202,7 @@ class LayerNormMLP(TransformerEngineBase):
input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
)
if self.enable_low_rank_adaptation:
......@@ -1260,6 +1275,7 @@ class LayerNormMLP(TransformerEngineBase):
input_axes=self.dot_2_input_axes,
kernel_axes=self.kernel_axes_2,
quantizer_set=ffn2_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
)
if self.enable_low_rank_adaptation:
......
......@@ -1207,6 +1207,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="qkv",
dtype=self.dtype,
)(inputs_q)
......@@ -1234,6 +1235,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="query",
)(inputs_q)
......@@ -1252,6 +1254,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
transpose_batch_sequence=self.transpose_batch_sequence,
name="kv",
dtype=self.dtype,
)(inputs_kv)
......@@ -1292,6 +1295,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="query",
)(inputs_q)
......@@ -2070,6 +2074,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
transpose_batch_sequence=self.transpose_batch_sequence,
name="mlp",
)(mlp_input, deterministic=deterministic)
......
......@@ -16,6 +16,7 @@ import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .quantize import (
QuantizerSet,
......@@ -35,6 +36,7 @@ def layernorm_dense(
norm_type: str = "layernorm",
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
transpose_batch_sequence: bool = False,
layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
......@@ -55,6 +57,7 @@ def layernorm_dense(
norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
......@@ -83,6 +86,7 @@ def layernorm_dense(
norm_type,
zero_centered_gamma,
epsilon,
transpose_batch_sequence,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
......@@ -100,6 +104,7 @@ def layernorm_dense(
8,
9,
10,
11,
),
)
def _layernorm_dense(
......@@ -111,6 +116,7 @@ def _layernorm_dense(
norm_type: str,
zero_centered_gamma: bool,
epsilon: float,
transpose_batch_sequence: bool,
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...],
......@@ -131,6 +137,7 @@ def _layernorm_dense(
norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding
quantizer_set: Set of quantizers
......@@ -147,6 +154,7 @@ def _layernorm_dense(
norm_type,
zero_centered_gamma,
epsilon,
transpose_batch_sequence,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
......@@ -164,6 +172,7 @@ def _layernorm_dense_fwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
transpose_batch_sequence,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
......@@ -194,6 +203,8 @@ def _layernorm_dense_fwd_rule(
epsilon,
norm_type,
quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
......@@ -203,6 +214,8 @@ def _layernorm_dense_fwd_rule(
kernel,
flatten_axis=flatten_axis,
quantizer=quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
......@@ -213,6 +226,7 @@ def _layernorm_dense_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=transpose_batch_sequence,
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
)
......@@ -245,6 +259,7 @@ def _layernorm_dense_bwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
transpose_batch_sequence,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
......@@ -285,6 +300,8 @@ def _layernorm_dense_bwd_rule(
is_dbias=use_bias,
flatten_axis=flatten_axis,
quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
......@@ -301,6 +318,7 @@ def _layernorm_dense_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel,
contracting_dims=(g_constracting_dim, k_constracting_dim),
transpose_batch_sequence=transpose_batch_sequence,
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
......@@ -314,6 +332,7 @@ def _layernorm_dense_bwd_rule(
casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_constracting_dim, g_constracting_dim),
transpose_batch_sequence=transpose_batch_sequence,
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
......@@ -41,7 +41,7 @@ def layernorm_mlp(
norm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
batch_sequence_transpose: bool = False,
transpose_batch_sequence: bool = False,
norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
......@@ -78,7 +78,7 @@ def layernorm_mlp(
norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
......@@ -130,7 +130,7 @@ def layernorm_mlp(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
transpose_batch_sequence,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -158,7 +158,7 @@ def _layernorm_mlp(
norm_type: str,
zero_centered_gamma: bool,
epsilon: float,
batch_sequence_transpose: bool,
transpose_batch_sequence: bool,
norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...],
......@@ -188,7 +188,7 @@ def _layernorm_mlp(
norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for layernorm sharding
dot_1_input_axes: Logical axes for first matrix multiplication sharding
dot_2_input_axes: Logical axes for second matrix multiplication sharding
......@@ -214,7 +214,7 @@ def _layernorm_mlp(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
transpose_batch_sequence,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -241,7 +241,7 @@ def _layernorm_mlp_fwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
transpose_batch_sequence,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -302,11 +302,16 @@ def _layernorm_mlp_fwd_rule(
norm_type,
quantizer=ffn1_quantizer_set.x,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(
kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP
kernel_1,
flatten_axis=-2,
quantizer=ffn1_quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
)
# NN GEMM
......@@ -315,7 +320,7 @@ def _layernorm_mlp_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_1.forward,
......@@ -345,6 +350,8 @@ def _layernorm_mlp_fwd_rule(
if activation_params
else None
),
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
)
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
......@@ -353,6 +360,7 @@ def _layernorm_mlp_fwd_rule(
kernel_2,
quantizer=ffn2_quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
)
# NN GEMM
......@@ -361,7 +369,7 @@ def _layernorm_mlp_fwd_rule(
casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_2.forward,
......@@ -403,7 +411,7 @@ def _layernorm_mlp_bwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
transpose_batch_sequence,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -465,6 +473,7 @@ def _layernorm_mlp_bwd_rule(
is_dbias=use_bias_2,
quantizer=ffn1_quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......@@ -482,7 +491,7 @@ def _layernorm_mlp_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2,
contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set_2.backward,
)
......@@ -498,7 +507,7 @@ def _layernorm_mlp_bwd_rule(
casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
......@@ -513,6 +522,8 @@ def _layernorm_mlp_bwd_rule(
if activation_params
else None
),
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......@@ -530,7 +541,7 @@ def _layernorm_mlp_bwd_rule(
casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set_1.backward,
)
......@@ -542,7 +553,7 @@ def _layernorm_mlp_bwd_rule(
casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment