Commit ab3e5a92 authored by yuguo's avatar yuguo
Browse files

Merge commit '04c730c0' of...

Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
...@@ -31,6 +31,9 @@ ...@@ -31,6 +31,9 @@
#include "transformer_engine/activation.h" #include "transformer_engine/activation.h"
#include "utils.h" #include "utils.h"
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
...@@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D ...@@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x);
// Normalization // Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler);
...@@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler); ...@@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, DType out_dtype, DType w_dtype, DType out_dtype,
NVTE_Norm_Type norm_type, int scaling_mode, NVTE_Norm_Type norm_type,
JAXX_Scaling_Mode scaling_mode,
bool zero_centered_gamma, float epsilon, int sm_margin, bool zero_centered_gamma, float epsilon, int sm_margin,
bool is_training); bool is_training);
...@@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); ...@@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype); DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode,
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); QuantizeLayout q_layout);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
int scaling_mode, bool is_2x);
// Softmax // Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
......
...@@ -11,21 +11,13 @@ ...@@ -11,21 +11,13 @@
#include "transformer_engine/cast.h" #include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
namespace {
bool is_gated(NVTE_Activation_Type act_type) {
return act_type == NVTE_Activation_Type::GEGLU || act_type == NVTE_Activation_Type::SWIGLU ||
act_type == NVTE_Activation_Type::REGLU || act_type == NVTE_Activation_Type::QGEGLU ||
act_type == NVTE_Activation_Type::SREGLU;
}
} // namespace
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, int64_t scaling_mode_enum, Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int) { bool is_2x_int) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
...@@ -42,40 +34,59 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -42,40 +34,59 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto n = input_dims.back(); auto n = input_dims.back();
auto act_type = static_cast<NVTE_Activation_Type>(act_enum); auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto act_len = input_dims[input_dims.size() - 2]; auto act_len = input_dims[input_dims.size() - 2];
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto is_2x = static_cast<bool>(is_2x_int); auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
auto input_shape = std::vector<size_t>{m, act_len * n}; auto input_shape = std::vector<size_t>{m, act_len * n};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype)); auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(scaling_mode); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape); output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv( if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
scale_inv_buf->untyped_data(), NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
std::vector<size_t>{ cudaMemsetAsync(amax, 0, sizeof(float), stream);
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1), output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
scale_inv_buf->dimensions().back()}); output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
} output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); } else {
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); output_tensor.set_rowwise_scale_inv(
cudaMemsetAsync(amax, 0, sizeof(float), stream); scale_inv_buf->untyped_data(),
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1}); std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
} }
if (is_2x) { if (is_2x) {
output_tensor.set_columnwise_data(colwise_output, static_cast<DType>(out_dtype), output_shape); auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
output_tensor.set_columnwise_scale_inv( ? output_trans_shape
colwise_scale_inv_buf->untyped_data(), : output_shape;
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1), if (is_fp8_dtype(out_dtype)) {
colwise_scale_inv_buf->dimensions().back()}); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? scale_inv_buf
: colwise_scale_inv_buf;
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
}
} }
switch (act_type) { switch (act_type) {
...@@ -128,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, ...@@ -128,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ret<Buffer_Type>() // scale_inv colwise .Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<int64_t>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"), .Attr<bool>("is_2x"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
int scaling_mode, bool is_2x) { JAXX_Scaling_Mode scaling_mode, bool is_2x) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size}; auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size}; auto output_shape = std::vector<size_t>{batch_size, hidden_size};
...@@ -153,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -153,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
auto dact_input_tensor = auto dact_input_tensor =
TensorWrapper(reinterpret_cast<void *>(&temp), dact_input_shape, in_dtype); TensorWrapper(reinterpret_cast<void *>(&temp), dact_input_shape, in_dtype);
auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype); auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
auto output_tensor = TensorWrapper(static_cast<NVTEScalingMode>(scaling_mode)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape); output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter // Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
...@@ -162,8 +173,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -162,8 +173,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
} }
if (is_2x) { if (is_2x) {
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
output_trans_shape); : output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter // Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
...@@ -172,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -172,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
} }
} }
if (is_fp8_dtype(out_dtype) && scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) { if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32, output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1}); std::vector<size_t>{1});
output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32, output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
...@@ -190,22 +202,25 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -190,22 +202,25 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type scale_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_out_buf, Result_Type dbias_buf, Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
bool is_dbias, int64_t act_enum) { int64_t act_enum, bool is_2x, bool is_dbias) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
auto *input = input_buf.untyped_data(); auto *input = input_buf.untyped_data();
auto *act_input = act_input_buf.untyped_data(); auto *act_input = act_input_buf.untyped_data();
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum); auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data(); auto *colwise_output = colwise_output_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data(); auto *dbias = dbias_buf->untyped_data();
void *workspace = workspace_buf->untyped_data(); void *workspace = workspace_buf->untyped_data();
...@@ -213,67 +228,76 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -213,67 +228,76 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto act_input_dims = act_input_buf.dimensions(); auto act_input_dims = act_input_buf.dimensions();
auto workspace_dims = workspace_buf->dimensions(); auto workspace_dims = workspace_buf->dimensions();
// m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
// n = ir_dz_shape[-1], ir_dz_shape == input_dims // n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
auto input_ranks = input_dims.size(); auto act_len = act_input_dims[act_input_dims.size() - 2];
auto act_input_ranks = act_input_dims.size(); NVTE_CHECK(act_input_dims.back() == input_dims.back(),
auto m = product(act_input_dims, 0, act_input_dims.size() - 1); "Shape mismatch between activation input and gradient input");
// 'n' will be 2x the size of input_dims.back() if the dactivation is dgated auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
auto n = act_input_dims.back(); auto n = input_dims.back();
auto input_shape = std::vector<size_t>{m, input_dims.back()};
auto act_input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n}; auto act_input_shape = std::vector<size_t>{m, n * act_len};
auto output_trans_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n * act_len};
auto dbias_shape = std::vector<size_t>{n}; auto output_trans_shape = std::vector<size_t>{n * act_len, m};
auto dbias_shape = std::vector<size_t>{n * act_len};
std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end()); std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
auto output_tensor = TensorWrapper(scaling_mode);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, out_dtype, output_shape); output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv( if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax_out, 0, sizeof(float), stream); cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
} }
} }
if (is_2x) { if (is_2x) {
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf = auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf; ? scale_inv_buf
output_tensor.set_columnwise_scale_inv( : colwise_scale_inv_buf;
colwise_scale_inv_buf->untyped_data(), if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), output_tensor.set_columnwise_scale_inv(
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
colwise_scale_inv_buf->dimensions().size() - 1), std::vector<size_t>{1});
colwise_scale_inv_buf->dimensions().back()}); } else {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
} }
} }
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK(!(is_gated(act_type) && is_dbias), "Unsupported DGatedActedDBias Fusion!"); NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
is_gated(act_type)),
"TE/common does not support delayed scaling for 2x with gated activations."); "TE/common does not support delayed scaling for 2x with gated activations.");
if (is_dbias) { if (is_dbias) {
...@@ -361,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI ...@@ -361,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // dbias .Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias"),
.Attr<int64_t>("act_enum"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -15,47 +15,34 @@ ...@@ -15,47 +15,34 @@
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
constexpr static size_t MXFP8_BLOCK_SIZE = 32; Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
Variadic_Result_Type output_list, int64_t num_gemms,
// Note: we only support TN-GEMM for now (TN in cuBLASLt == NT in JAX) JAXX_Scaling_Mode scaling_mode, int64_t has_bias) {
Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lhs_sinv_ptr,
const DType &lhs_sinv_dtype, uint8_t *rhs_ptr, const DType &rhs_dtype,
uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr,
const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype,
uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms,
int32_t *dim_list_ptr, const int64_t &scaling_mode,
cudaStream_t stream) {
size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
"sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
size_t dim_list_bytes = sizeof(int32_t) * 3 * num_gemms;
std::unique_ptr<int32_t[]> dim_list_host = std::make_unique<int32_t[]>(3 * num_gemms);
cudaMemcpyAsync(dim_list_host.get(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream);
// Note: This may break cudaGraph.
cudaStreamSynchronize(stream);
// Notes on matrix layouts and transpose: // Notes on matrix layouts and transpose:
// Jax uses row-major layout, on entering this function, each input matrix pair: // Jax uses row-major data_layout, on entering this function, each input matrix pair:
// A: row-major with size [m, k], // A: row-major with size [m, k],
// B: row-major with size [n, k], needs transpose, // B: row-major with size [n, k], needs transpose,
// on exiting this function, JAX expect: // on exiting this function, JAX expect:
// C: row-major with size [m, n]. // C: row-major with size [m, n].
// cuBLAS uses column-major layout, in this view, each input matrix pair: // cuBLAS uses column-major data_layout, in this view, each input matrix pair:
// A: column-major with size [k, m], needs transpose, // A: column-major with size [k, m], needs transpose,
// B: column-major with size [k, n]. // B: column-major with size [k, n].
// If we call cuBLAS GEMM for A * B, the output will be: // If we call cuBLAS GEMM for A * B, the output will be:
// C: column-major with size [m, n] --> row-major with size [n, m]. // C: column-major with size [m, n] --> row-major with size [n, m].
// To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.
if (num_gemms <= 0) {
return ffi_with_cuda_error_check();
}
size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms;
size_t expected_output_size = num_gemms + 1;
size_t actual_input_size = input_list.size();
size_t actual_output_size = output_list.size();
NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu",
expected_input_size, actual_input_size);
NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu",
expected_output_size, actual_output_size);
bool trans_lhs = true; bool trans_lhs = true;
bool trans_rhs = false; bool trans_rhs = false;
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0); auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
...@@ -79,10 +66,40 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh ...@@ -79,10 +66,40 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
std::vector<NVTETensor> out_list; std::vector<NVTETensor> out_list;
std::vector<NVTETensor> workspace_list; std::vector<NVTETensor> workspace_list;
int lhs_list_offset = 0;
int rhs_list_offset = num_gemms;
int lhs_sinv_list_offset = 2 * num_gemms;
int rhs_sinv_list_offset = 3 * num_gemms;
int bias_list_offset = 4 * num_gemms;
int out_list_offset = 0;
for (int i = 0; i < num_gemms; i++) { for (int i = 0; i < num_gemms; i++) {
size_t m = dim_list_host[i * 3]; Buffer_Type lhs_i = input_list.get<Buffer_Type>(lhs_list_offset + i).value();
size_t n = dim_list_host[i * 3 + 1]; Buffer_Type rhs_i = input_list.get<Buffer_Type>(rhs_list_offset + i).value();
size_t k = dim_list_host[i * 3 + 2]; Buffer_Type lhs_sinv_i = input_list.get<Buffer_Type>(lhs_sinv_list_offset + i).value();
Buffer_Type rhs_sinv_i = input_list.get<Buffer_Type>(rhs_sinv_list_offset + i).value();
Result_Type out_i = output_list.get<Buffer_Type>(out_list_offset + i).value();
DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type());
DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type());
DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type());
void *lhs_ptr = lhs_i.untyped_data();
void *rhs_ptr = rhs_i.untyped_data();
void *lhs_sinv_ptr = lhs_sinv_i.untyped_data();
void *rhs_sinv_ptr = rhs_sinv_i.untyped_data();
void *out_ptr = out_i->untyped_data();
// Placeholder for bias since it can be empty
DType bias_dtype = DType::kFloat32;
void *bias_ptr = nullptr;
auto lhs_shape_ = lhs_i.dimensions();
auto rhs_shape_ = rhs_i.dimensions();
// lhs and rhs has shape [1, m, k] and [1, n, k]
size_t m = lhs_shape_[1];
size_t n = rhs_shape_[1];
size_t k = lhs_shape_[2];
auto lhs_shape = std::vector<size_t>{m, k}; auto lhs_shape = std::vector<size_t>{m, k};
auto rhs_shape = std::vector<size_t>{n, k}; auto rhs_shape = std::vector<size_t>{n, k};
...@@ -90,54 +107,54 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh ...@@ -90,54 +107,54 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
auto lhs_sinv_shape = std::vector<size_t>{1, 1}; auto lhs_sinv_shape = std::vector<size_t>{1, 1};
auto rhs_sinv_shape = std::vector<size_t>{1, 1}; auto rhs_sinv_shape = std::vector<size_t>{1, 1};
if (scaling_mode == NVTE_NO_SCALING || scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
auto lhs_i = TensorWrapper(static_cast<void *>(lhs_ptr), lhs_shape, lhs_dtype, nullptr, scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
nullptr, reinterpret_cast<float *>(lhs_sinv_ptr)); float *amax_dptr = nullptr;
auto rhs_i = TensorWrapper(static_cast<void *>(rhs_ptr), rhs_shape, rhs_dtype, nullptr, float *scale_dptr = nullptr;
nullptr, reinterpret_cast<float *>(rhs_sinv_ptr)); auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
lhs_wrapper_list.push_back(std::move(lhs_i)); reinterpret_cast<float *>(lhs_sinv_ptr));
rhs_wrapper_list.push_back(std::move(rhs_i)); auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr,
} else if (scaling_mode == NVTE_MXFP8_1D_SCALING) { reinterpret_cast<float *>(rhs_sinv_ptr));
NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)", lhs_wrapper_list.push_back(std::move(lhs_i_));
MXFP8_BLOCK_SIZE, k); rhs_wrapper_list.push_back(std::move(rhs_i_));
size_t sinv_k = k / MXFP8_BLOCK_SIZE; } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
lhs_sinv_shape[0] = m;
lhs_sinv_shape[1] = sinv_k;
rhs_sinv_shape[0] = n;
rhs_sinv_shape[1] = sinv_k;
// Note: the scale_inv array should have been swizzled in Python before lowering // Note: the scale_inv array should have been swizzled in Python before lowering
TensorWrapper lhs_i(NVTE_MXFP8_1D_SCALING); auto lhs_sinv_shape_ = lhs_sinv_i.dimensions();
TensorWrapper rhs_i(NVTE_MXFP8_1D_SCALING); auto rhs_sinv_shape_ = rhs_sinv_i.dimensions();
lhs_i.set_rowwise_data(static_cast<void *>(lhs_ptr), lhs_dtype, lhs_shape); for (int i = 0; i < 2; i++) {
rhs_i.set_rowwise_data(static_cast<void *>(rhs_ptr), rhs_dtype, rhs_shape); lhs_sinv_shape[i] = lhs_sinv_shape_[i];
lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat8E8M0, rhs_sinv_shape[i] = rhs_sinv_shape_[i];
lhs_sinv_shape); }
rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat8E8M0,
rhs_sinv_shape); NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode);
TensorWrapper lhs_i_(nvte_scaling_mode);
lhs_wrapper_list.push_back(std::move(lhs_i)); TensorWrapper rhs_i_(nvte_scaling_mode);
rhs_wrapper_list.push_back(std::move(rhs_i)); lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape);
rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape);
lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape);
rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape);
lhs_wrapper_list.push_back(std::move(lhs_i_));
rhs_wrapper_list.push_back(std::move(rhs_i_));
} else { } else {
NVTE_ERROR("Unsupported scaling mode: ", scaling_mode); NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode));
} }
auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype); auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype);
lhs_ptr += m * k * lhs_dtype_bytes;
rhs_ptr += n * k * rhs_dtype_bytes;
out_ptr += m * n * out_dtype_bytes;
lhs_sinv_ptr += lhs_sinv_shape[0] * lhs_sinv_shape[1] * lhs_sinv_dtype_bytes;
rhs_sinv_ptr += rhs_sinv_shape[0] * rhs_sinv_shape[1] * rhs_sinv_dtype_bytes;
void *pre_gelu_ptr = nullptr; void *pre_gelu_ptr = nullptr;
auto bias_shape = std::vector<size_t>{0}; auto bias_shape = std::vector<size_t>{0};
auto pre_gelu_shape = std::vector<size_t>{0}; auto pre_gelu_shape = std::vector<size_t>{0};
if (bias_ptr != nullptr) bias_shape[0] = n; if (has_bias) {
auto bias_i_get = input_list.get<Buffer_Type>(bias_list_offset + i);
Buffer_Type bias_i = bias_i_get.value();
bias_ptr = bias_i.untyped_data();
bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type());
bias_shape[0] = n;
}
auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
if (bias_ptr != nullptr) bias_ptr += n * bias_dtype_bytes;
auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype); auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype);
out_wrapper_list.push_back(std::move(out_i)); out_wrapper_list.push_back(std::move(out_i_));
bias_wrapper_list.push_back(std::move(bias_i)); bias_wrapper_list.push_back(std::move(bias_i));
pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));
...@@ -148,6 +165,10 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh ...@@ -148,6 +165,10 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
out_list.push_back(out_wrapper_list.back().data()); out_list.push_back(out_wrapper_list.back().data());
} }
auto workspace_get = output_list.get<Buffer_Type>(num_gemms);
Result_Type workspace = workspace_get.value();
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
size_t workspace_size = workspace->dimensions()[0] / num_streams;
auto workspace_shape = std::vector<size_t>{workspace_size}; auto workspace_shape = std::vector<size_t>{workspace_size};
for (int i = 0; i < num_streams; i++) { for (int i = 0; i < num_streams; i++) {
auto workspace_i = auto workspace_i =
...@@ -165,49 +186,14 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh ...@@ -165,49 +186,14 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
return ffi_with_cuda_error_check(); return ffi_with_cuda_error_check();
} }
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten,
Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten,
Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten,
Buffer_Type dim_list, Result_Type out_flatten,
Result_Type workspace_flatten, int64_t num_gemms, int64_t scaling_mode) {
// Inputs
auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_flatten.untyped_data());
auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_flatten.untyped_data());
auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv_flatten.untyped_data());
auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv_flatten.untyped_data());
auto bias_ptr = reinterpret_cast<uint8_t *>(bias_flatten.untyped_data());
auto dim_list_ptr = reinterpret_cast<int32_t *>(dim_list.untyped_data());
auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_flatten.element_type());
auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_flatten.element_type());
auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv_flatten.element_type());
auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv_flatten.element_type());
auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias_flatten.element_type());
// Outputs
auto out_ptr = reinterpret_cast<uint8_t *>(out_flatten->untyped_data());
auto out_dtype = convert_ffi_datatype_to_te_dtype(out_flatten->element_type());
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace_flatten->untyped_data());
auto workspace_size = workspace_flatten->dimensions().back() / num_streams;
return GroupedGemmImpl(lhs_ptr, lhs_dtype, lhs_sinv_ptr, lhs_sinv_dtype, rhs_ptr, rhs_dtype,
rhs_sinv_ptr, rhs_sinv_dtype, bias_ptr, bias_dtype, out_ptr, out_dtype,
workspace_ptr, workspace_size, num_gemms, dim_list_ptr, scaling_mode,
stream);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
FFI::Bind() FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream .Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // lhs_flatten .RemainingArgs() // input list
.Arg<Buffer_Type>() // lhs_sinv_flatten .RemainingRets() // output list
.Arg<Buffer_Type>() // rhs_flatten
.Arg<Buffer_Type>() // rhs_sinv_flatten
.Arg<Buffer_Type>() // bias_flatten
.Arg<Buffer_Type>() // dim_list
.Ret<Buffer_Type>() // out_flatten
.Ret<Buffer_Type>() // workspace_flatten
.Attr<int64_t>("num_gemms") .Attr<int64_t>("num_gemms")
.Attr<int64_t>("scaling_mode"), .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("has_bias"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
} // namespace jax } // namespace jax
......
...@@ -34,11 +34,34 @@ inline size_t product(const std::vector<size_t> &shape) { ...@@ -34,11 +34,34 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret; return ret;
} }
enum class QuantizeAxis { enum class QuantizeLayout {
ROWWISE, ROWWISE,
COLWISE, COLWISE,
ROWWISE_COLWISE, ROWWISE_COLWISE,
}; };
enum class JAXX_Scaling_Mode : int64_t {
NO_SCALING = 0,
DELAYED_TENSOR_SCALING = 1,
MXFP8_1D_SCALING = 2,
};
static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
switch (mode) {
case JAXX_Scaling_Mode::NO_SCALING:
return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
break;
case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING:
return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
break;
case JAXX_Scaling_Mode::MXFP8_1D_SCALING:
return NVTEScalingMode::NVTE_MXFP8_1D_SCALING;
break;
default:
NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
break;
}
}
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -14,7 +14,8 @@ namespace jax { ...@@ -14,7 +14,8 @@ namespace jax {
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, DType out_dtype, DType w_dtype, DType out_dtype,
NVTE_Norm_Type norm_type, int scaling_mode, NVTE_Norm_Type norm_type,
JAXX_Scaling_Mode scaling_mode,
bool zero_centered_gamma, float epsilon, int sm_margin, bool zero_centered_gamma, float epsilon, int sm_margin,
bool is_training) { bool is_training) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
...@@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si ...@@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype); auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
auto _scaling_mode = static_cast<NVTEScalingMode>(scaling_mode); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto output_tensor = TensorWrapper(_scaling_mode);
output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape); output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape);
// WAR: NVTE Norms query the is_training from whereas columwise_data is allocated // WAR: NVTE Norms query the is_training from whereas columwise_data is allocated
if (is_training && _scaling_mode == NVTE_MXFP8_1D_SCALING) { if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
int temp = 1; int temp = 1;
output_tensor.set_columnwise_data(static_cast<void *>(&temp), out_dtype, input_shape); output_tensor.set_columnwise_data(static_cast<void *>(&temp), out_dtype, input_shape);
} }
...@@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si ...@@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr); dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr);
} else { } else {
NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma,
"rmsnorm doesn't support zero_centered_gamma."); "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), epsilon, output_tensor.data(), nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), epsilon, output_tensor.data(),
rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma,
...@@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf,
Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf,
int norm_type, bool zero_centered_gamma, double epsilon, int norm_type, bool zero_centered_gamma, double epsilon,
int64_t sm_margin, int scaling_mode, bool is_2x) { int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
...@@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto *amax = reinterpret_cast<float *>(amax_buf->untyped_data()); auto *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto *workspace = wkspace_buf->untyped_data(); auto *workspace = wkspace_buf->untyped_data();
auto _scaling_mode = static_cast<NVTEScalingMode>(scaling_mode);
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type); auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x); auto _is_2x = static_cast<bool>(is_2x);
...@@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin;
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
auto output_tensor = TensorWrapper(_scaling_mode); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape); output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
...@@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
scale_inv_buf->dimensions().back()}); scale_inv_buf->dimensions().back()});
} }
if (_scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax, 0, sizeof(float), stream); cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
...@@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
workspace_tensor.data(), num_sm, zero_centered_gamma, stream); workspace_tensor.data(), num_sm, zero_centered_gamma, stream);
} else { } else {
NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma, NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma,
"rmsnorm doesn't support zero_centered_gamma."); "rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), _epsilon, output_tensor.data(), nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), _epsilon, output_tensor.data(),
rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma,
...@@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, ...@@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<bool>("zero_centered_gamma") .Attr<bool>("zero_centered_gamma")
.Attr<double>("epsilon") .Attr<double>("epsilon")
.Attr<int64_t>("sm_margin") .Attr<int64_t>("sm_margin")
.Attr<int64_t>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"), .Attr<bool>("is_2x"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
......
...@@ -138,17 +138,17 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -138,17 +138,17 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("RMSNorm", NVTE_Norm_Type::RMSNorm) .value("RMSNorm", NVTE_Norm_Type::RMSNorm)
.export_values(); .export_values();
pybind11::enum_<NVTEScalingMode>(m, "NVTE_Scaling_Mode", pybind11::module_local()) pybind11::enum_<JAXX_Scaling_Mode>(m, "JAXX_Scaling_Mode", pybind11::module_local())
.value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING)
.value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
.value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
.export_values(); .export_values();
pybind11::enum_<transformer_engine::jax::QuantizeAxis>(m, "QuantizeAxis", pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
pybind11::module_local()) pybind11::module_local())
.value("ROWWISE", transformer_engine::jax::QuantizeAxis::ROWWISE) .value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE)
.value("COLWISE", transformer_engine::jax::QuantizeAxis::COLWISE) .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeAxis::ROWWISE_COLWISE) .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
.export_values(); .export_values();
} }
......
...@@ -13,7 +13,9 @@ namespace transformer_engine { ...@@ -13,7 +13,9 @@ namespace transformer_engine {
namespace jax { namespace jax {
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) { DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size}; auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size}; auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
...@@ -27,10 +29,37 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ ...@@ -27,10 +29,37 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
int temp = 0; int temp = 0;
auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype); auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
auto output_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), output_shape, out_dtype);
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_trans_shape);
auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype); auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) {
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
}
}
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) {
auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
}
}
if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
}
TensorWrapper dummy_workspace; TensorWrapper dummy_workspace;
nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
...@@ -42,10 +71,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ ...@@ -42,10 +71,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf, Result_Type output_buf, Result_Type output_trans_buf,
Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_out_buf, Result_Type dbias_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum, JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
int64_t quantize_axis_enum, bool is_dbias) { bool is_dbias, int64_t flatten_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
...@@ -54,8 +83,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -54,8 +83,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto *input = input_buf.untyped_data(); auto *input = input_buf.untyped_data();
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum); auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto const quantize_axis = static_cast<QuantizeAxis>(quantize_axis_enum);
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data(); auto *output_trans = output_trans_buf->untyped_data();
...@@ -63,9 +91,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -63,9 +91,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
void *workspace = workspace_buf->untyped_data(); void *workspace = workspace_buf->untyped_data();
auto input_dims = input_buf.dimensions(); auto input_dims = input_buf.dimensions();
int64_t input_ndim = input_dims.size();
if (flatten_axis < 0) flatten_axis += input_ndim;
NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!");
auto workspace_dims = workspace_buf->dimensions(); auto workspace_dims = workspace_buf->dimensions();
auto m = product(input_dims, 0, input_dims.size() - 1); auto m = product(input_dims, 0, flatten_axis);
auto n = input_dims.back(); auto n = product(input_dims, flatten_axis, input_ndim);
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m}; auto output_trans_shape = std::vector<size_t>{n, m};
...@@ -73,39 +105,58 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -73,39 +105,58 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
std::vector<size_t> workspace_shape{workspace_dims.begin(), workspace_dims.end()}; std::vector<size_t> workspace_shape{workspace_dims.begin(), workspace_dims.end()};
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(scaling_mode); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
if (quantize_axis == QuantizeAxis::ROWWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) { if (quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape); output_tensor.set_rowwise_data(output, out_dtype, output_shape);
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
}
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { if (is_fp8_dtype(out_dtype)) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data()); if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data()); float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax_out, 0, sizeof(float), stream); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1}); cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
}
} }
if (quantize_axis == QuantizeAxis::COLWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) { if (quantize_layout == QuantizeLayout::COLWISE ||
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf = auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf; ? scale_inv_buf
output_tensor.set_columnwise_scale_inv( : colwise_scale_inv_buf;
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->dimensions().size() - 1), tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
colwise_scale_inv_buf->dimensions().back()}); std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
} }
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
...@@ -132,9 +183,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, ...@@ -132,9 +183,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // dbias .Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_axis") .Attr<int64_t>("q_layout")
.Attr<bool>("is_dbias"), .Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
......
...@@ -15,7 +15,11 @@ import jax ...@@ -15,7 +15,11 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .quantize import QuantizerSet, noop_quantizer_set from .quantize import (
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
)
def dense( def dense(
...@@ -23,6 +27,8 @@ def dense( ...@@ -23,6 +27,8 @@ def dense(
kernel: jnp.ndarray, kernel: jnp.ndarray,
bias: jnp.ndarray = None, bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
): ):
"""Perform dense layer transformation with optional quantization. """Perform dense layer transformation with optional quantization.
...@@ -48,12 +54,12 @@ def dense( ...@@ -48,12 +54,12 @@ def dense(
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
else: else:
output = _dense(x, kernel, bias, contracting_dims, quantizer_set) output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set)
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(3,)) @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _dense(x, kernel, bias, contracting_dims, quantizer_set): def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
"""Internal implementation of dense layer transformation with custom VJP. """Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support This function implements the core dense layer transformation logic with support
...@@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set): ...@@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set):
kernel: Weight matrix kernel: Weight matrix
bias: Optional bias tensor bias: Optional bias tensor
contracting_dims: Contracting dimensions specification contracting_dims: Contracting dimensions specification
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
Transformed output tensor Transformed output tensor
""" """
output, _ = _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set) output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set
)
return output return output
def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
"""Forward pass rule for dense layer transformation. """Forward pass rule for dense layer transformation.
Args:
x: Input tensor
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
Tuple of (output, context) for backward pass Tuple of (output, context) for backward pass
""" """
x_contracting_dims, k_contracting_dims = contracting_dims x_contracting_dims, k_contracting_dims = contracting_dims
casted_x = tex.quantize(x, quantizer_set.x) flatten_axis_x = -len(x_contracting_dims)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# GEMM NN # GEMM NN
output = tex.gemm( output = tex.gemm(
...@@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): ...@@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
casted_kernel.get_colwise_tensor(), casted_kernel.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
) )
use_bias = bias is not None use_bias = bias is not None
if use_bias: if use_bias:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
...@@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): ...@@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
kernel.shape, kernel.shape,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k,
) )
return output, ctx return output, ctx
def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argument def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, ctx, grad
): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation. """Backward pass rule for dense layer transformation.
Args:
contracting_dims: Contracting dimensions specification
ctx: Context from forward pass
grad: Gradient from upstream
Returns: Returns:
Tuple of gradients with respect to inputs Tuple of gradients with respect to inputs
""" """
...@@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu ...@@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
kernel_shape, kernel_shape,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k,
) = ctx ) = ctx
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad) casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad
)
# GEMM NT # GEMM NT
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
...@@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu ...@@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
rowwise_casted_kernel, rowwise_casted_kernel,
(g_constracting_dim, k_constracting_dim), (g_constracting_dim, k_constracting_dim),
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
# GEMM TN # GEMM TN
# x_non_contracting_dims # x_non_contracting_dims
...@@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu ...@@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
wgrad = tex.gemm( wgrad = tex.gemm(
colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim) colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim)
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
return dgrad, wgrad, dbias, quantizer_set return dgrad, wgrad, dbias, quantizer_set
......
...@@ -13,7 +13,6 @@ import jax.numpy as jnp ...@@ -13,7 +13,6 @@ import jax.numpy as jnp
from flax import linen as nn from flax import linen as nn
from flax.linen import partitioning as nn_partitioning from flax.linen import partitioning as nn_partitioning
from jax import lax from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name from jax.ad_checkpoint import checkpoint_name
...@@ -26,8 +25,14 @@ from ..layernorm_mlp import layernorm_mlp ...@@ -26,8 +25,14 @@ from ..layernorm_mlp import layernorm_mlp
from ..activation import activation from ..activation import activation
from ..softmax import softmax, SoftmaxType from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available from ..cpp_extensions import (
is_softmax_kernel_available,
jax_scaled_softmax,
jax_scaled_masked_softmax,
jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
...@@ -167,10 +172,10 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods ...@@ -167,10 +172,10 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
input_dtype = inputs.dtype input_dtype = inputs.dtype
logits = inputs logits = inputs
if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available( # use primitives
if is_softmax_kernel_available(
self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
): ):
if bias is not None: if bias is not None:
logits = logits + bias.astype(input_dtype) logits = logits + bias.astype(input_dtype)
...@@ -179,31 +184,22 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods ...@@ -179,31 +184,22 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
mask_ = None mask_ = None
outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type) outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
# use default jax based implementation
else: else:
attention_bias = None
if mask is not None:
attention_bias = lax.select(
mask > 0,
jnp.full(mask.shape, -1e10),
jnp.full(mask.shape, 0.0),
)
attention_bias = attention_bias.astype(input_dtype)
if bias is not None: if bias is not None:
attention_bias = _combine_biases(attention_bias, bias) logits = logits + bias.astype(input_dtype)
if attention_bias is not None:
logits = logits + attention_bias.astype(input_dtype)
# For the case that self.softmax == SoftmaxType.SCALED_UPPER_TRIANG_MASKED if self.softmax_type is SoftmaxType.SCALED:
# and kernel is unavailable, then try on pure scaled softmax custom calls. outputs = jax_scaled_softmax(logits, self.scale_factor)
if is_softmax_kernel_available( elif self.softmax_type is SoftmaxType.SCALED_MASKED:
SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, input_dtype outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
): elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED) outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
else: else:
outputs = jax_nn.softmax(logits * self.scale_factor) raise ValueError(
f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
" SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
)
assert input_dtype == outputs.dtype assert input_dtype == outputs.dtype
return outputs return outputs
...@@ -360,7 +356,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method ...@@ -360,7 +356,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
).value ).value
return QuantizeMeta(scale=scale, amax_history=amax_history) return QuantizeMeta(scale=scale, amax_history=amax_history)
if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
x_meta = generate_quantize_meta("x") x_meta = generate_quantize_meta("x")
kernel_meta = generate_quantize_meta("kernel") kernel_meta = generate_quantize_meta("kernel")
grad_meta = generate_quantize_meta("grad") grad_meta = generate_quantize_meta("grad")
...@@ -406,6 +402,10 @@ class DenseGeneral(TransformerEngineBase): ...@@ -406,6 +402,10 @@ class DenseGeneral(TransformerEngineBase):
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling. :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
axis: Union[Iterable[int], int], default = -1 axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on. An integer tuple with axes to apply the transformation on.
input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
Optimization parameters Optimization parameters
----------------------- -----------------------
...@@ -429,6 +429,7 @@ class DenseGeneral(TransformerEngineBase): ...@@ -429,6 +429,7 @@ class DenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
input_axes: Tuple[str, ...] = ()
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -460,29 +461,35 @@ class DenseGeneral(TransformerEngineBase): ...@@ -460,29 +461,35 @@ class DenseGeneral(TransformerEngineBase):
axis = _normalize_axes(axis, inputs.ndim) axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
if self.kernel_axes:
assert len(kernel_shape) == len(self.kernel_axes), (
"Expected len(kernel_shape) to match len(kernel_axes),"
f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
)
kernel = nn_partitioning.param_with_axes( kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
) )
if not QuantizeConfig.is_fp8_enabled(): if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
kernel_compute_shape = (
reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
reduce(operator.mul, features, 1),
)
kernel = jnp.reshape(kernel, kernel_compute_shape)
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes( bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
) ).astype(input_dtype)
bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
else: else:
bias = None bias = None
quantizer_set = self.generate_quantizer_set() quantizer_set = self.generate_quantizer_set()
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
y = dense( y = dense(
inputs, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set inputs,
kernel,
contracting_dims=(axis, contract_ind),
input_axes=self.input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
...@@ -491,20 +498,14 @@ class DenseGeneral(TransformerEngineBase): ...@@ -491,20 +498,14 @@ class DenseGeneral(TransformerEngineBase):
*features[:-1], *features[:-1],
self.low_rank_adaptation_dim, self.low_rank_adaptation_dim,
) )
lora_a_kernel_init_shape = ( lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
kernel_compute_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel = nn_partitioning.param_with_axes( lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel", "lora_a_kernel",
self.kernel_init, self.kernel_init,
lora_a_kernel_init_shape, lora_a_kernel_shape,
self.dtype, self.dtype,
axes=lora_a_kernel_axes, axes=lora_a_kernel_axes,
) )
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
...@@ -527,7 +528,6 @@ class DenseGeneral(TransformerEngineBase): ...@@ -527,7 +528,6 @@ class DenseGeneral(TransformerEngineBase):
y += jnp.reshape(bias, bias_shape) y += jnp.reshape(bias, bias_shape)
assert y.dtype == input_dtype assert y.dtype == input_dtype
y = y.reshape(*inputs.shape[: self.axis], *features)
return y return y
...@@ -678,6 +678,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -678,6 +678,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
The output tensors of layer normalization. The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None. If :attr:`return_layernorm_output=False`, then this would be None.
""" """
assert self.axis == -1, "Only support axis = =-1 at this moment"
input_dtype = inputs.dtype input_dtype = inputs.dtype
ln_output = None ln_output = None
...@@ -692,10 +693,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -692,10 +693,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.enable_layernorm: if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
assert self.axis == -1 # Only support axis = =-1 at this moment
features = inputs.shape[-1] features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters( scale, ln_bias = _create_layernorm_parameters(
self.layernorm_type, self.layernorm_type,
(features,), (features,),
...@@ -731,17 +729,12 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -731,17 +729,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis = _normalize_axes(axis, y.ndim) axis = _normalize_axes(axis, y.ndim)
kernel_shape = tuple(y.shape[ax] for ax in axis) + features kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes( kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
) )
if not QuantizeConfig.is_fp8_enabled(): if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
kernel_compute_shape = (
reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
reduce(operator.mul, features, 1),
)
kernel = jnp.reshape(kernel, kernel_compute_shape)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
...@@ -756,11 +749,19 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -756,11 +749,19 @@ class LayerNormDenseGeneral(TransformerEngineBase):
epsilon=self.epsilon, epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes, layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_input_axes, dot_input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
) )
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
z = dense(y, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set) z = dense(
y,
kernel,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
)
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
lora_a_kernel_shape = ( lora_a_kernel_shape = (
...@@ -768,20 +769,14 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -768,20 +769,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
*features[:-1], *features[:-1],
self.low_rank_adaptation_dim, self.low_rank_adaptation_dim,
) )
lora_a_kernel_init_shape = ( lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
kernel_compute_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel = nn_partitioning.param_with_axes( lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel", "lora_a_kernel",
self.kernel_init, self.kernel_init,
lora_a_kernel_init_shape, lora_a_kernel_shape,
self.dtype, self.dtype,
axes=lora_a_kernel_axes, axes=lora_a_kernel_axes,
) )
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
...@@ -803,8 +798,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -803,8 +798,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes( bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
) ).astype(input_dtype)
bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
if bias is not None: if bias is not None:
bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
...@@ -814,7 +808,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -814,7 +808,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
z = z / self.depth_scaling z = z / self.depth_scaling
assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}" assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
z = z.reshape(*inputs.shape[: self.axis], *features) # z = z.reshape(*inputs.shape[: self.axis], *features)
return z, ln_output # dense_output, layer_norm_output return z, ln_output # dense_output, layer_norm_output
...@@ -989,6 +983,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -989,6 +983,8 @@ class LayerNormMLP(TransformerEngineBase):
The output tensors of layer normalization. The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None. If :attr:`return_layernorm_output=False`, then this would be None.
""" """
assert self.axis == -1, "Only support axis == -1 at this moment"
ffn1_quantizer_set = self.generate_quantizer_set("_0") ffn1_quantizer_set = self.generate_quantizer_set("_0")
ffn2_quantizer_set = self.generate_quantizer_set("_1") ffn2_quantizer_set = self.generate_quantizer_set("_1")
...@@ -1027,7 +1023,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1027,7 +1023,6 @@ class LayerNormMLP(TransformerEngineBase):
) )
# LayerNorm # LayerNorm
if self.enable_layernorm: if self.enable_layernorm:
assert self.axis == -1 # Only support axis == -1 at this moment
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
features = inputs.shape[-1] features = inputs.shape[-1]
...@@ -1071,7 +1066,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1071,7 +1066,7 @@ class LayerNormMLP(TransformerEngineBase):
num_activations = len(normalized_acts) num_activations = len(normalized_acts)
axis = _canonicalize_tuple(self.axis) axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, y.ndim) axis = _normalize_axes(axis, y.ndim)
kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim) kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1 = nn_partitioning.param_with_axes( kernel_1 = nn_partitioning.param_with_axes(
"wi_kernel", "wi_kernel",
kernel_1_init, kernel_1_init,
...@@ -1081,13 +1076,10 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1081,13 +1076,10 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype, self.dtype,
axes=self.kernel_axes_1, axes=self.kernel_axes_1,
) )
kernel_1_compute_shape = (
reduce(operator.mul, [y.shape[ax] for ax in axis], 1),
num_activations * self.intermediate_dim,
)
kernel_1 = jnp.reshape(kernel_1, kernel_1_compute_shape)
if not QuantizeConfig.is_fp8_enabled(): if not QuantizeConfig.is_fp8_enabled():
kernel_1 = kernel_1.astype(input_dtype) kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1] hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size) hidden_size_tuple = _canonicalize_tuple(hidden_size)
kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
...@@ -1098,26 +1090,20 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1098,26 +1090,20 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype, self.dtype,
axes=self.kernel_axes_2, axes=self.kernel_axes_2,
) )
kernel_2_compute_shape = (
self.intermediate_dim,
reduce(operator.mul, hidden_size_tuple, 1),
)
kernel_2 = jnp.reshape(kernel_2, kernel_2_compute_shape)
if not QuantizeConfig.is_fp8_enabled(): if not QuantizeConfig.is_fp8_enabled():
kernel_2 = kernel_2.astype(input_dtype) kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
if self.use_bias: if self.use_bias:
bias_1_shape = num_activations * self.intermediate_dim bias_1_shape = (num_activations, self.intermediate_dim)
bias_1 = nn_partitioning.param_with_axes( bias_1 = nn_partitioning.param_with_axes(
"wi_bias", "wi_bias",
self.bias_init, self.bias_init,
bias_1_shape, bias_1_shape,
self.dtype, self.dtype,
axes=self.bias_axes_1, axes=self.bias_axes_1,
) ).astype(input_dtype)
bias_1 = bias_1.reshape(kernel_1_compute_shape[-1]).astype(input_dtype)
bias_2_shape = (hidden_size,) bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes( bias_2 = nn_partitioning.param_with_axes(
...@@ -1126,8 +1112,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1126,8 +1112,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_2_shape, bias_2_shape,
self.dtype, self.dtype,
axes=self.bias_axes_2, axes=self.bias_axes_2,
) ).astype(input_dtype)
bias_2 = bias_2.reshape(kernel_2_compute_shape[-1]).astype(input_dtype)
else: else:
bias_1 = None bias_1 = None
bias_2 = None bias_2 = None
...@@ -1136,8 +1121,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1136,8 +1121,6 @@ class LayerNormMLP(TransformerEngineBase):
ffn2_ckpt_name = "ffn2" ffn2_ckpt_name = "ffn2"
if use_fused_layernorm_mlp: if use_fused_layernorm_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment
out = layernorm_mlp( out = layernorm_mlp(
y, y,
scale, scale,
...@@ -1150,6 +1133,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1150,6 +1133,8 @@ class LayerNormMLP(TransformerEngineBase):
norm_input_axes=self.layernorm_input_axes, norm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes, dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes, dot_2_input_axes=self.dot_2_input_axes,
kernel_1_axes=self.kernel_axes_1,
kernel_2_axes=self.kernel_axes_2,
ffn1_ckpt_name=ffn1_ckpt_name, ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name, ffn2_ckpt_name=ffn2_ckpt_name,
activation_type=normalized_acts, activation_type=normalized_acts,
...@@ -1170,6 +1155,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1170,6 +1155,7 @@ class LayerNormMLP(TransformerEngineBase):
epsilon=self.epsilon, epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes, layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_1_input_axes, dot_input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
) )
else: else:
...@@ -1178,35 +1164,31 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1178,35 +1164,31 @@ class LayerNormMLP(TransformerEngineBase):
y, y,
kernel_1, kernel_1,
contracting_dims=(axis, contract_ind), contracting_dims=(axis, contract_ind),
input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
) )
dot_1_output_axes = (
*get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
*get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind),
)
x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
wi_lora_a_kernel_shape = ( wi_lora_a_kernel_each_shape = (
kernel_1_compute_shape[0], kernel_1_each_shape[: len(axis)],
num_activations,
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_init_shape = (
kernel_1_each_shape[0],
num_activations,
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_init_each_shape = (
kernel_1_each_shape[0],
self.low_rank_adaptation_dim, self.low_rank_adaptation_dim,
) )
wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape) wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_each_shape + 1)
wi_lora_a_kernel = nn_partitioning.param_with_axes( wi_lora_a_kernel = nn_partitioning.param_with_axes(
"wi_lora_a_kernel", "wi_lora_a_kernel",
kernel_1_init, kernel_1_init,
num_activations, num_activations,
-1, -2,
wi_lora_a_kernel_init_each_shape, wi_lora_a_kernel_each_shape,
self.dtype, self.dtype,
axes=wi_lora_a_kernel_axes, axes=wi_lora_a_kernel_axes,
) )
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype) wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)
wi_lora_b_kernel_shape = ( wi_lora_b_kernel_shape = (
...@@ -1227,7 +1209,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1227,7 +1209,7 @@ class LayerNormMLP(TransformerEngineBase):
x += _apply_low_rank_adaptation( x += _apply_low_rank_adaptation(
y, y,
axis, axis,
num_activations * self.intermediate_dim, (num_activations, self.intermediate_dim),
wi_lora_a_kernel, wi_lora_a_kernel,
wi_lora_b_kernel, wi_lora_b_kernel,
self.low_rank_adaptation_alpha, self.low_rank_adaptation_alpha,
...@@ -1241,11 +1223,12 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1241,11 +1223,12 @@ class LayerNormMLP(TransformerEngineBase):
z = activation(x, normalized_acts) z = activation(x, normalized_acts)
else: else:
activations = [] activations = []
x = jnp.split(x, num_activations, axis=-1) x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(normalized_acts): for idx, act_fn in enumerate(normalized_acts):
x_i = _convert_to_activation_function(act_fn)(x[idx]) x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i) activations.append(x_i)
z = reduce(operator.mul, activations) z = reduce(operator.mul, activations)
z = jnp.squeeze(z, axis=-2)
z = z.astype(input_dtype) z = z.astype(input_dtype)
z = nn.Dropout( z = nn.Dropout(
...@@ -1259,7 +1242,12 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1259,7 +1242,12 @@ class LayerNormMLP(TransformerEngineBase):
# DenseGeneral 2 # DenseGeneral 2
out = dense( out = dense(
z, kernel_2, contracting_dims=(axis, contract_ind), quantizer_set=ffn2_quantizer_set z,
kernel_2,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_2_input_axes,
kernel_axes=self.kernel_axes_2,
quantizer_set=ffn2_quantizer_set,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
......
...@@ -220,11 +220,11 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- ...@@ -220,11 +220,11 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if mask is not None: if mask is not None:
mask = apply_swa_mask(mask) mask = apply_swa_mask(mask)
# Currently cuDNN backend only supports SWA for causal/padding_causal, follow this # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: if mask is not None:
return SoftmaxType.SCALED_MASKED, mask
if attn_mask_type is AttnMaskType.CAUSAL_MASK:
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]: if attn_mask_type is AttnMaskType.NO_MASK:
if mask is not None:
return SoftmaxType.SCALED_MASKED, mask
return SoftmaxType.SCALED, mask return SoftmaxType.SCALED, mask
raise ValueError( raise ValueError(
f"Unsupported {attn_mask_type=}, supported attn_mask_type=" f"Unsupported {attn_mask_type=}, supported attn_mask_type="
...@@ -447,6 +447,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -447,6 +447,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
.. note:: THD format only supports 'padding' or 'causal_padding' mask type. .. note:: THD format only supports 'padding' or 'causal_padding' mask type.
attn_mask_type mask/sequence_descriptor SWA softmax type
--------------------------------------------------------------------------------------------
no_mask None None SCALED
causal None None SCALED_UPPER_TRIANG_MASKED
causal None Yes SCALED_MASKED
padding Required Yes/No SCALED_MASKED
padding_causal Required Yes/No SCALED_MASKED
attn_bias_type: Optional[str], default = None attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention. Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
......
...@@ -33,10 +33,9 @@ def layernorm_dense( ...@@ -33,10 +33,9 @@ def layernorm_dense(
norm_type: str = "layernorm", norm_type: str = "layernorm",
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
epsilon: float = 1e-6, epsilon: float = 1e-6,
# The logic axes of sharding constraint to the layernorm input.
layernorm_input_axes: Tuple[str, ...] = None, layernorm_input_axes: Tuple[str, ...] = None,
# The logic axes of sharding constraint to the dot input.
dot_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation. """Apply layer normalization followed by dense layer transformation.
...@@ -56,6 +55,7 @@ def layernorm_dense( ...@@ -56,6 +55,7 @@ def layernorm_dense(
epsilon: Small constant for numerical stability in normalization epsilon: Small constant for numerical stability in normalization
layernorm_input_axes: Logical axes for sharding the layernorm input layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: Set of quantizers for different tensor types quantizer_set: Set of quantizers for different tensor types
Returns: Returns:
...@@ -78,6 +78,7 @@ def layernorm_dense( ...@@ -78,6 +78,7 @@ def layernorm_dense(
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -91,6 +92,7 @@ def layernorm_dense( ...@@ -91,6 +92,7 @@ def layernorm_dense(
7, 7,
8, 8,
9, 9,
10,
), ),
) )
def _layernorm_dense( def _layernorm_dense(
...@@ -104,6 +106,7 @@ def _layernorm_dense( ...@@ -104,6 +106,7 @@ def _layernorm_dense(
epsilon: float, epsilon: float,
layernorm_input_axes: Tuple[str, ...], layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...],
quantizer_set, quantizer_set,
): ):
"""Internal implementation of layernorm_dense with custom VJP. """Internal implementation of layernorm_dense with custom VJP.
...@@ -139,6 +142,7 @@ def _layernorm_dense( ...@@ -139,6 +142,7 @@ def _layernorm_dense(
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule( ...@@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule(
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes,
quantizer_set, quantizer_set,
): ):
"""Forward pass rule for layernorm_dense. """Forward pass rule for layernorm_dense.
...@@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule( ...@@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule(
x_contracting_dims = (len(x.shape) - 1,) x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,) k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0] assert x.shape[-1] == kernel.shape[0]
assert len(kernel.shape) == 2 # Otherwise need to merge dims in quantize
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
...@@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule( ...@@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule(
norm_type, norm_type,
quantizer_set.x, quantizer_set.x,
) )
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# Kernel in (hidden_in, hidden_out...) # Kernel in (hidden_in, hidden_out...)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel) flatten_axis = 1 - len(kernel.shape)
casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...) # (batch..., hidden_in) x (hidden_in, hidden_out...)
...@@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule( ...@@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims, k_contracting_dims,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis,
) )
return output, ctx return output, ctx
...@@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule( ...@@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule(
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument dot_input_axes, # pylint: disable=unused-argument
kernel_axes,
ctx, ctx,
grad, grad,
): ):
...@@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule( ...@@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule(
k_contracting_dims_in_fwd, k_contracting_dims_in_fwd,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis,
) = ctx ) = ctx
grad = with_sharding_constraint_by_logical_axes(grad, dot_input_axes) casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
g_constracting_dim = tuple( g_constracting_dim = tuple(
...@@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule( ...@@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule(
(x_constracting_dim, g_constracting_dim), (x_constracting_dim, g_constracting_dim),
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
dx, dgamma, dbeta = tex.normalization_bwd( dx, dgamma, dbeta = tex.normalization_bwd(
dgrad, dgrad,
x, x,
......
...@@ -23,6 +23,7 @@ from jax.ad_checkpoint import checkpoint_name ...@@ -23,6 +23,7 @@ from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type from .layernorm import canonicalize_norm_type
from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set
from .sharding import get_non_contracting_logical_axes
def layernorm_mlp( def layernorm_mlp(
...@@ -37,6 +38,8 @@ def layernorm_mlp( ...@@ -37,6 +38,8 @@ def layernorm_mlp(
norm_input_axes: Tuple[str, ...] = None, norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None,
kernel_1_axes: Tuple[str, ...] = None,
kernel_2_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = "ffn1", ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2", ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
...@@ -66,6 +69,8 @@ def layernorm_mlp( ...@@ -66,6 +69,8 @@ def layernorm_mlp(
norm_input_axes: Logical axes for sharding the layernorm input norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication
kernel_1_axes: Logical axes for sharding the first weight matrix
kernel_2_axes: Logical axes for sharding the second weight matrix
ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation activation_type: Activation function(s) to apply after the first dense layer transformation
...@@ -109,6 +114,8 @@ def layernorm_mlp( ...@@ -109,6 +114,8 @@ def layernorm_mlp(
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
...@@ -117,7 +124,7 @@ def layernorm_mlp( ...@@ -117,7 +124,7 @@ def layernorm_mlp(
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15)) @partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp( def _layernorm_mlp(
x: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
...@@ -132,6 +139,8 @@ def _layernorm_mlp( ...@@ -132,6 +139,8 @@ def _layernorm_mlp(
norm_input_axes: Tuple[str, ...], norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
kernel_1_axes: Tuple[str, ...],
kernel_2_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn1_ckpt_name: str,
ffn2_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
...@@ -179,6 +188,8 @@ def _layernorm_mlp( ...@@ -179,6 +188,8 @@ def _layernorm_mlp(
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
...@@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule( ...@@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule(
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
...@@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule( ...@@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule(
Returns: Returns:
Tuple of (output, context) for automatic differentiation Tuple of (output, context) for automatic differentiation
""" """
del kernel_2_axes
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
# x should be in shape of (batch..., hidden) # x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len * intermediate) # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
# Kernel_2 should be in shape of (intermediate, hidden_in) # Kernel_2 should be in shape of (intermediate, hidden_in)
assert len(kernel_1.shape) == 2 assert len(kernel_1.shape) == 3
assert len(kernel_2.shape) == 2 assert len(kernel_2.shape) == 2
assert kernel_1.shape[1] == kernel_2.shape[0] * len(activation_type) assert kernel_1.shape[-2] == len(activation_type)
x_contracting_dims = (len(x.shape) - 1,) x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,) k_contracting_dims = (0,)
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
assert kernel_1.shape[1] == len(activation_type) * kernel_2.shape[0]
use_bias_1 = bias_1 is not None use_bias_1 = bias_1 is not None
use_bias_2 = bias_1 is not None use_bias_2 = bias_1 is not None
...@@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule( ...@@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule(
norm_type, norm_type,
quantizer=ffn1_quantizer_set.x, quantizer=ffn1_quantizer_set.x,
) )
casted_kernel_1 = tex.quantize(kernel_1, quantizer=ffn1_quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel)
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out) # (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = tex.gemm( dot_1_output = tex.gemm(
...@@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule( ...@@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule(
casted_kernel_1.get_colwise_tensor(), casted_kernel_1.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
) )
dot_1_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
)
dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
if use_bias_1: if use_bias_1:
bias_1_shape = bias_1.shape bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
...@@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule( ...@@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule(
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
) )
dot_2_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_2_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_2.ndim, None, k_contracting_dims),
)
dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_2_output_axes)
if use_bias_2: if use_bias_2:
bias_2_shape = bias_2.shape bias_2_shape = bias_2.shape
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
...@@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule( ...@@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule(
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument kernel_1_axes,
ffn2_ckpt_name, # pylint: disable=unused-argument kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type, activation_type,
ctx, ctx,
grad, grad,
...@@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule(
Returns: Returns:
Tuple of gradients for all input parameters Tuple of gradients for all input parameters
""" """
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
( (
x, x,
mu, mu,
...@@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule( ...@@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule(
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_2 = tuple( g_contracting_dims_2 = tuple(
range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim) range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
) )
# k_non_contracting_dims # k_non_contracting_dims
k_constracting_dim_2 = tuple( k_contracting_dims_2 = tuple(
dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
) )
...@@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule( ...@@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule(
dgrad_2 = tex.gemm( dgrad_2 = tex.gemm(
casted_grad.get_rowwise_tensor(), casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel_2, rowwise_casted_kernel_2,
(g_constracting_dim_2, k_constracting_dim_2), (g_contracting_dims_2, k_contracting_dims_2),
) )
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
x_constracting_dim = g_constracting_dim = tuple( x_contracting_dims = g_contracting_dims = tuple(
range(0, len(x.shape) - len(x_contracting_dims_in_fwd)) range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
) )
...@@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule( ...@@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule(
wgrad_2 = tex.gemm( wgrad_2 = tex.gemm(
colwise_casted_act_out, colwise_casted_act_out,
casted_grad.get_colwise_tensor(), casted_grad.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim), (x_contracting_dims, g_contracting_dims),
) )
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
casted_dact_out, dbias_1 = tex.quantize_dact_dbias( casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
dgrad_2, dgrad_2,
...@@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule( ...@@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule(
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_1 = tuple( dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim
range(dgrad_2.ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dgrad_2.ndim) g_contracting_dims_1 = tuple(
range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
) )
# k_non_contracting_dims # k_non_contracting_dims
k_constracting_dim_1 = tuple( k_contracting_dims_1 = tuple(
dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
) )
...@@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule( ...@@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule(
dgrad_1 = tex.gemm( dgrad_1 = tex.gemm(
casted_dact_out.get_rowwise_tensor(), casted_dact_out.get_rowwise_tensor(),
rowwise_casted_kernel_1, rowwise_casted_kernel_1,
(g_constracting_dim_1, k_constracting_dim_1), (g_contracting_dims_1, k_contracting_dims_1),
) )
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, norm_input_axes) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
# TN GEMM # TN GEMM
# (hidden, batch...) x (hidden, batch...) # (hidden, batch...) x (hidden, batch...)
wgrad_1 = tex.gemm( wgrad_1 = tex.gemm(
colwise_casted_ln_out, colwise_casted_ln_out,
casted_dact_out.get_colwise_tensor(), casted_dact_out.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim), (x_contracting_dims, g_contracting_dims),
) )
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
dx, dgamma, dbeta = tex.normalization_bwd( dx, dgamma, dbeta = tex.normalization_bwd(
dgrad_1, dgrad_1,
x, x,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Praxis related Modules"""
from .module import FusedSoftmax, LayerNorm
from .module import LayerNormLinear, LayerNormMLP, Linear, TransformerEngineBaseLayer
from .transformer import DotProductAttention, MultiHeadAttention
from .transformer import RelativePositionBiases, TransformerLayer
from ..flax.transformer import TransformerLayerType
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules
"""
from dataclasses import field
from functools import partial
from typing import Callable, Iterable, Sequence, Tuple, Union
from praxis import pax_fiddle
from praxis.base_layer import init_var
from praxis.base_layer import BaseLayer, WeightInit, WeightHParams, WeightHParamsCollection
from praxis.layers import flax_adapter
from praxis.pytypes import JTensor
from ..fp8 import FP8Helper
from ..flax.module import DenseGeneral, LayerNormDenseGeneral
from ..flax.module import LayerNorm as flax_LayerNorm
from ..flax.module import LayerNormMLP as flax_LayerNormMLP
from ..flax.module import Softmax
from ..softmax import SoftmaxType
def _generate_ln_scale_init(scale_init):
if scale_init is not None:
return TransformerEngineBaseLayer.generate_params_init("scale", scale_init)
return scale_init
class TransformerEngineBaseLayer(BaseLayer):
    """Common Praxis base layer for TransformerEngine Flax-module adapters.

    Provides two helpers shared by all adapters below: converting Praxis
    ``WeightInit`` specs into Flax-style init functions, and wrapping a Flax
    module factory in a ``FlaxModuleAdapter`` child layer with the FP8
    variable-collection tags applied.
    """

    # Flax logical-axis to mesh-axis sharding rules, forwarded to the adapter.
    logical_axes_rules: Tuple[Tuple, ...] = None

    @staticmethod
    def generate_params_init(name: str, initializer: WeightInit):
        """Build a Flax init function ``f(key, shape, dtype)`` backed by a
        Praxis ``WeightInit`` spec, initializing a variable named ``name``."""

        def _init_fn(key, shape, dtype):
            hparams = WeightHParams(shape=shape, init=initializer, dtype=dtype)
            return init_var(hparams, key, name)

        return _init_fn

    def create_layer(self, name, flax_module_cls):
        """Register ``flax_module_cls`` as a child layer wrapped in a
        ``FlaxModuleAdapter``, tagging FP8 metadata variables so they skip LP
        regularization, are overwritten with gradients, and are never cast to
        bfloat16."""
        fp8_var_tags = [
            WeightHParamsCollection.SKIP_LP_REGULARIZATION,
            WeightHParamsCollection.OVERWRITE_WITH_GRADIENT,
            WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION,
        ]
        adapter_cfg = pax_fiddle.Config(
            flax_adapter.FlaxModuleAdapter,
            module_factory_method=flax_module_cls,
            logical_axes_rules=self.logical_axes_rules,
            var_collection_map={FP8Helper.FP8_COLLECTION_NAME: fp8_var_tags},
            ici_mesh_shape=self.ici_mesh_shape,
            dcn_mesh_shape=self.dcn_mesh_shape,
            mesh_axis_names=self.mesh_axis_names,
        )
        self.create_child(name, adapter_cfg.clone())
class LayerNorm(TransformerEngineBaseLayer):
    """Praxis adapter around the TransformerEngine Flax LayerNorm module."""

    # All fields below are forwarded verbatim to the wrapped flax_LayerNorm.
    epsilon: float = 1e-6
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: WeightInit = None
    scale_axes: Tuple[str, ...] = ()
    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=0.0)
    )
    bias_axes: Tuple[str, ...] = ()
    transpose_batch_sequence: bool = False

    def setup(self) -> None:
        """Create the wrapped Flax LayerNorm child layer from this config."""
        super().setup()

        # Gather constructor arguments in one place, then bind them into the
        # module factory handed to create_layer.
        ln_kwargs = dict(
            epsilon=self.epsilon,
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            scale_init=_generate_ln_scale_init(self.scale_init),
            scale_axes=self.scale_axes,
            bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", self.bias_init),
            bias_axes=self.bias_axes,
            dtype=self.dtype,
            transpose_batch_sequence=self.transpose_batch_sequence,
        )
        self.create_layer("layer_norm", partial(flax_LayerNorm, **ln_kwargs))

    def __call__(self, x: JTensor) -> JTensor:
        """Apply layer normalization to ``x``."""
        return self.layer_norm(x)
class FusedSoftmax(TransformerEngineBaseLayer):
    """Praxis adapter around the TransformerEngine Flax Softmax module."""

    # Multiplier applied by the wrapped Softmax; forwarded verbatim.
    scale_factor: float = 1.0
    # Which softmax kernel variant to run; forwarded verbatim.
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    def setup(self) -> None:
        """Create the wrapped Softmax child layer from this config."""
        super().setup()

        softmax_factory = partial(
            Softmax,
            scale_factor=self.scale_factor,
            softmax_type=self.softmax_type,
        )
        self.create_layer("fused_softmax", softmax_factory)

    def __call__(self, x: JTensor, mask: JTensor = None, bias: JTensor = None) -> JTensor:
        """Apply the fused softmax to ``x`` with optional ``mask`` and ``bias``."""
        return self.fused_softmax(x, mask, bias)
class Linear(TransformerEngineBaseLayer):
    """Praxis adapter around the TransformerEngine Flax DenseGeneral layer."""

    # All fields below are forwarded to the wrapped DenseGeneral.
    out_features: int = 512
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=0.0)
    )
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    transpose_batch_sequence: bool = False

    def setup(self) -> None:
        """Create the wrapped DenseGeneral child layer from this config."""
        super().setup()

        # Collect the DenseGeneral constructor arguments in one place, then
        # bind them into the factory handed to create_layer.
        dense_kwargs = dict(
            features=self.out_features,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            kernel_axes=self.kernel_axes,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            bias_axes=self.bias_axes,
            enable_low_rank_adaptation=self.enable_low_rank_adaptation,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            axis=self.axis,
            dtype=self.dtype,
            transpose_batch_sequence=self.transpose_batch_sequence,
        )
        self.create_layer("linear", partial(DenseGeneral, **dense_kwargs))

    def __call__(self, x: JTensor) -> JTensor:
        """Apply the wrapped linear (dense) transformation to ``x``."""
        return self.linear(x)
class LayerNormLinear(TransformerEngineBaseLayer):
    """Praxis adapter around the TE/Flax LayerNormDenseGeneral module.

    Every config field below is forwarded to LayerNormDenseGeneral in
    ``setup``; the initializer fields are first wrapped into Flax-style init
    functions via ``generate_params_init`` / ``_generate_ln_scale_init``.
    """

    # Output width of the dense projection.
    out_features: int = 512
    # Whether to run the fused layernorm before the dense projection.
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    # Layernorm scale (gamma) initializer; None keeps the module's default.
    scale_init: WeightInit = None
    scale_axes: Tuple[str, ...] = ()
    # Layernorm bias (beta) initializer; defaults to constant 1.0 here.
    ln_bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=1.0)
    )
    ln_bias_axes: Tuple[str, ...] = ()
    # Logical sharding axes of the dense kernel.
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=0.0)
    )
    bias_axes: Tuple[str, ...] = ()
    # LoRA settings forwarded to the wrapped module.
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    # If True the wrapped module also returns the layernorm output.
    return_layernorm_output: bool = True
    # Input axis (or axes) contracted by the dense projection.
    axis: Union[Iterable[int], int] = -1
    transpose_batch_sequence: bool = False
    depth_scaling: float = None

    def setup(self) -> None:
        """Create the wrapped LayerNormDenseGeneral child layer, translating
        Praxis WeightInit specs into Flax init functions."""
        super().setup()

        ln_dense_general_cls = partial(
            LayerNormDenseGeneral,
            features=self.out_features,
            enable_layernorm=self.enable_layernorm,
            layernorm_type=self.layernorm_type,
            epsilon=self.epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            scale_init=_generate_ln_scale_init(self.scale_init),
            scale_axes=self.scale_axes,
            ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
                "ln_bias", self.ln_bias_init
            ),
            ln_bias_axes=self.ln_bias_axes,
            # The dense kernel reuses the layer-wide params_init policy.
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            kernel_axes=self.kernel_axes,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            bias_axes=self.bias_axes,
            enable_low_rank_adaptation=self.enable_low_rank_adaptation,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            return_layernorm_output=self.return_layernorm_output,
            axis=self.axis,
            dtype=self.dtype,
            transpose_batch_sequence=self.transpose_batch_sequence,
            depth_scaling=self.depth_scaling,
        )

        self.create_layer("ln_linear", ln_dense_general_cls)

    def __call__(self, x: JTensor) -> JTensor:
        """Apply layernorm (if enabled) followed by the dense projection."""
        return self.ln_linear(x)
class LayerNormMLP(TransformerEngineBaseLayer):
    """Praxis adapter around the TE/Flax LayerNormMLP module (layernorm +
    two dense layers with an activation in between).

    Every config field below is forwarded to flax_LayerNormMLP in ``setup``;
    the initializer fields are first wrapped into Flax-style init functions
    via ``generate_params_init`` / ``_generate_ln_scale_init``.
    """

    # Width of the hidden (intermediate) MLP layer.
    intermediate_dim: int = 2048
    # Whether to run the fused layernorm before the first dense layer.
    enable_layernorm: bool = True
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    # Layernorm scale (gamma) initializer; None keeps the module's default.
    scale_init: WeightInit = None
    scale_axes: Tuple[str, ...] = ()
    # Layernorm bias (beta) initializer; defaults to constant 1.0 here.
    ln_bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=1.0)
    )
    ln_bias_axes: Tuple[str, ...] = ()
    # Logical sharding axes for the first and second dense kernels.
    kernel_axes_1: Tuple[str, ...] = ()
    kernel_axes_2: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=0.0)
    )
    # Logical sharding axes for the first and second dense biases.
    bias_axes_1: Tuple[str, ...] = ()
    bias_axes_2: Tuple[str, ...] = ()
    # LoRA settings forwarded to the wrapped module.
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    # If True the wrapped module also returns the layernorm output.
    return_layernorm_output: bool = True
    # Activation function(s) applied between the two dense layers.
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    # Input axis (or axes) contracted by the first dense layer.
    axis: Union[Iterable[int], int] = -1
    transpose_batch_sequence: bool = False

    def setup(self) -> None:
        """Create the wrapped LayerNormMLP child layer, translating Praxis
        WeightInit specs into Flax init functions."""
        super().setup()

        ln_mlp_cls = partial(
            flax_LayerNormMLP,
            intermediate_dim=self.intermediate_dim,
            enable_layernorm=self.enable_layernorm,
            layernorm_type=self.layernorm_type,
            epsilon=self.epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            scale_init=_generate_ln_scale_init(self.scale_init),
            scale_axes=self.scale_axes,
            ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
                "ln_bias", self.ln_bias_init
            ),
            ln_bias_axes=self.ln_bias_axes,
            # Both dense kernels reuse the layer-wide params_init policy.
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            kernel_axes_1=self.kernel_axes_1,
            kernel_axes_2=self.kernel_axes_2,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            bias_axes_1=self.bias_axes_1,
            bias_axes_2=self.bias_axes_2,
            enable_low_rank_adaptation=self.enable_low_rank_adaptation,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            return_layernorm_output=self.return_layernorm_output,
            activations=self.activations,
            intermediate_dropout_rate=self.intermediate_dropout_rate,
            intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims,
            axis=self.axis,
            dtype=self.dtype,
            transpose_batch_sequence=self.transpose_batch_sequence,
        )
        self.create_layer("ln_mlp", ln_mlp_cls)

    def __call__(self, x: JTensor, deterministic: bool = False) -> JTensor:
        """Apply the layernorm-MLP block; ``deterministic`` disables dropout
        randomness in the wrapped module."""
        return self.ln_mlp(x, deterministic)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules related Transformer
"""
from dataclasses import field
from functools import partial
from typing import Optional, Sequence, Tuple
import warnings
from praxis import pax_fiddle
from praxis.base_layer import WeightInit
from praxis.pytypes import JTensor
from .module import TransformerEngineBaseLayer
from ..flax.transformer import TransformerLayerType
from ..flax.transformer import DotProductAttention as flax_DotProductAttention
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer
from ..attention import AttnBiasType, AttnMaskType
class RelativePositionBiases(TransformerEngineBaseLayer):
    """Praxis adapter around the TE/Flax RelativePositionBiases module."""

    # Bucketing / sizing parameters forwarded to the wrapped module.
    num_buckets: int = 32
    max_distance: int = 128
    num_attention_heads: int = 64
    # Embedding initializer; None selects the Gaussian default below.
    embedding_init: WeightInit = None
    embedding_axes: Tuple[str, ...] = ()

    @staticmethod
    def generate_embedding_init(init, num_attention_heads, num_buckets):
        """Return ``init`` unchanged, or a Gaussian default whose stddev is
        ``(num_attention_heads * num_buckets) ** -0.5`` when ``init`` is None."""
        if init is not None:
            return init
        rb_stddev = (num_attention_heads * num_buckets) ** -0.5
        return WeightInit.Gaussian(rb_stddev)

    def setup(self) -> None:
        """Create the wrapped RelativePositionBiases child layer."""
        super().setup()

        resolved_init = RelativePositionBiases.generate_embedding_init(
            self.embedding_init, self.num_attention_heads, self.num_buckets
        )

        rpb_factory = partial(
            flax_RelativePositionBiases,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
            num_attention_heads=self.num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", resolved_init
            ),
            embedding_axes=self.embedding_axes,
            dtype=self.dtype,
        )
        self.create_layer("relative_position_bias", rpb_factory)

    def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = True) -> JTensor:
        """Compute relative position biases for the given query/key lengths."""
        return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
class DotProductAttention(TransformerEngineBaseLayer):
    """Praxis adapter around the TE/Flax DotProductAttention module."""

    # All fields below are forwarded to the wrapped flax_DotProductAttention.
    head_dim: int = 0
    num_attention_heads: int = 0
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = "causal"
    attn_bias_type: AttnBiasType = None
    dropout_rng_name: str = "dropout"
    float32_logits: bool = False
    qkv_layout: str = "bshd_bshd_bshd"
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None

    def setup(self) -> None:
        """Validate the config and create the wrapped attention child layer."""
        super().setup()

        # head_dim / num_attention_heads have no usable defaults (0); callers
        # must set them explicitly.
        assert self.head_dim > 0, f"{self.head_dim=}"
        assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"

        attn_kwargs = dict(
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=self.qkv_layout,
            scale_factor=self.scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
            window_size=self.window_size,
        )
        self.create_layer("dot_product_attention", partial(flax_DotProductAttention, **attn_kwargs))

    def __call__(
        self,
        query: JTensor,
        key: JTensor,
        value: JTensor,
        mask: Optional[JTensor] = None,
        bias: Optional[JTensor] = None,
        *,
        deterministic: bool = False,
    ) -> JTensor:
        """Run fused dot-product attention over query/key/value; ``deterministic``
        disables dropout randomness in the wrapped module."""
        return self.dot_product_attention(
            query, key, value, mask, bias, deterministic=deterministic
        )
class MultiHeadAttention(TransformerEngineBaseLayer):
    """Praxis interface for the Transformer Engine multi-head attention layer.

    Wraps ``flax_MultiHeadAttention`` as a Praxis layer: each dataclass field
    below is forwarded verbatim to the Flax module inside ``setup``.  The
    deprecated fields (``num_heads``, ``dropout_rate``, ``output_layernorm``,
    ``apply_residual_connection_post_layernorm``, ``fuse_qkv``) are remapped
    onto their replacements in ``__post_init__`` with a ``DeprecationWarning``.
    """

    head_dim: int = 0  # size of each attention head; must be set (> 0)
    num_attention_heads: int = 0  # number of query heads; must be set (> 0)
    # Number of key/value heads for grouped-query attention; defaults to
    # num_attention_heads (i.e. standard multi-head attention).
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
    input_layernorm: bool = True
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    return_layernorm_output: bool = False
    use_bias: bool = False
    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=0.0)
    )
    attn_mask_type: str = "causal"
    attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    # Annotation fixed: default is None, so the field is Optional.
    low_rank_adaptation_alpha: Optional[float] = None
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = True
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False
    window_size: Optional[Tuple[int, int]] = None
    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None

    def __post_init__(self):
        """Remap deprecated fields onto the new API, then run Praxis init."""
        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please uses {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )

        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )

        if self.apply_residual_connection_post_layernorm is not None:
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning,
            )

        if self.fuse_qkv is not None:
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )

        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        if self.num_gqa_groups is None:
            # BUGFIX: fall back to num_attention_heads (the new API) rather than
            # the deprecated num_heads, which is None unless explicitly set. At
            # this point num_attention_heads already reflects num_heads if the
            # caller used the deprecated field.
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    def setup(self) -> None:
        """Instantiate the underlying flax_MultiHeadAttention child layer."""
        super().setup()

        assert self.head_dim > 0, f"{self.head_dim=}"
        assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"

        # All configuration is forwarded one-to-one to the Flax module; the
        # kernel/bias initializers are wrapped so Praxis sharding metadata is
        # attached to the generated parameters.
        mha_cls = partial(
            flax_MultiHeadAttention,
            dtype=self.dtype,
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            input_layernorm=self.input_layernorm,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.return_layernorm_output,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            fuse_qkv_params=self.fuse_qkv_params,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            float32_logits=self.float32_logits,
            window_size=self.window_size,
        )

        self.create_layer("multi_head_attn", mha_cls)

    def __call__(
        self,
        inputs_q: JTensor,
        inputs_kv: JTensor,
        mask: Optional[JTensor] = None,
        bias: Optional[JTensor] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> JTensor:
        """Apply multi-head attention to query and key/value inputs.

        Delegates to the ``multi_head_attn`` child layer created in ``setup``.
        """
        return self.multi_head_attn(
            inputs_q, inputs_kv, mask, bias, decode=decode, deterministic=deterministic
        )
class TransformerLayer(TransformerEngineBaseLayer):
    """Praxis interface for the Transformer Engine transformer layer.

    Wraps ``flax_TransformerLayer`` as a Praxis layer: each dataclass field
    below is forwarded verbatim to the Flax module inside ``setup``.  When
    ``enable_relative_embedding`` is set and a ``relative_embedding`` template
    is given, a ``flax_RelativePositionBiases`` module is built and shared
    with the wrapped layer.
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    # Number of key/value heads for grouped-query attention; defaults to
    # num_attention_heads (i.e. standard multi-head attention).
    num_gqa_groups: Optional[int] = None
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = "dropout"
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=0.0)
    )
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = "causal"
    self_attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    # Annotation fixed: default is None, so the field is Optional.
    low_rank_adaptation_alpha: Optional[float] = None
    enable_relative_embedding: bool = True
    relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    window_size: Optional[Tuple[int, int]] = None

    def __post_init__(self):
        """Resolve defaulted fields before Praxis initialization."""
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    def setup(self) -> None:
        """Build the optional relative-bias module and the transformer layer."""
        super().setup()

        relative_embedding_flax_module = None
        if self.enable_relative_embedding and self.relative_embedding is not None:
            # The relative-bias head count must match the attention head count,
            # since the bias is added per attention head.
            assert self.relative_embedding.num_attention_heads == self.num_attention_heads, (
                "TransformerLayer.relative_embedding.num_attention_heads should be "
                "the same as TransformerLayer.num_attention_heads."
            )

            embedding_init = RelativePositionBiases.generate_embedding_init(
                self.relative_embedding.embedding_init,
                self.relative_embedding.num_attention_heads,
                self.relative_embedding.num_buckets,
            )

            relative_embedding_flax_module = flax_RelativePositionBiases(
                num_buckets=self.relative_embedding.num_buckets,
                max_distance=self.relative_embedding.max_distance,
                num_attention_heads=self.relative_embedding.num_attention_heads,
                embedding_init=TransformerEngineBaseLayer.generate_params_init(
                    "rel_embedding", embedding_init
                ),
                embedding_axes=self.relative_embedding.embedding_axes,
                dtype=self.relative_embedding.dtype,
            )

        # All configuration is forwarded one-to-one to the Flax module; the
        # kernel initializers are wrapped so Praxis sharding metadata is
        # attached to the generated parameters.
        transformerlayer_cls = partial(
            flax_TransformerLayer,
            dtype=self.dtype,
            hidden_size=self.hidden_size,
            mlp_hidden_size=self.mlp_hidden_size,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            hidden_dropout=self.hidden_dropout,
            hidden_dropout_dims=self.hidden_dropout_dims,
            attention_dropout=self.attention_dropout,
            intermediate_dropout=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            dropout_rng_name=self.dropout_rng_name,
            mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mha_kernel", self.params_init
            ),
            mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mlp_kernel", self.params_init
            ),
            mlp_activations=self.mlp_activations,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            output_layernorm=self.output_layernorm,
            float32_attention_logits=self.float32_attention_logits,
            layer_type=self.layer_type,
            self_attn_mask_type=self.self_attn_mask_type,
            self_attn_bias_type=self.self_attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            enable_relative_embedding=self.enable_relative_embedding,
            relative_embedding=relative_embedding_flax_module,
            drop_path=self.drop_path,
            fuse_qkv_params=self.fuse_qkv_params,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            window_size=self.window_size,
        )

        self.create_layer("transformerlayer", transformerlayer_cls)

    def __call__(
        self,
        inputs: JTensor,
        encoded: Optional[JTensor] = None,
        attention_mask: Optional[JTensor] = None,
        encoder_decoder_mask: Optional[JTensor] = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: Optional[int] = None,  # annotation fixed (was `bool`)
    ) -> JTensor:
        """Apply the transformer layer.

        Delegates to the ``transformerlayer`` child layer created in ``setup``.
        ``encoded`` and ``encoder_decoder_mask`` are only used for decoder
        layers with cross-attention.
        """
        return self.transformerlayer(
            inputs,
            encoded,
            attention_mask,
            encoder_decoder_mask,
            deterministic,
            decode,
            max_decode_length,
        )
...@@ -57,26 +57,35 @@ class Dequantizer: ...@@ -57,26 +57,35 @@ class Dequantizer:
data = scaled_tensor.data.astype(jnp.float32) data = scaled_tensor.data.astype(jnp.float32)
data_shape = data.shape data_shape = data.shape
scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32) scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32)
flatten_axis = scaled_tensor.flatten_axis
flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
scale_shape = scaled_tensor.scaling_mode.get_scale_shape( scale_shape = scaled_tensor.scaling_mode.get_scale_shape(
scaled_tensor.data.shape, scaled_tensor.is_colwise, is_padded=False data_shape, scaled_tensor.is_colwise, is_padded=False, flatten_axis=flatten_axis
) )
scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding
data = data.reshape( data = data.reshape(
*data_shape[:-2], *data_shape[: flatten_axis - 1],
scale_shape[-2], scale_shape[flatten_axis - 1],
int(data_shape[-2] / scale_shape[-2]), int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*data_shape[flatten_axis:-1],
scale_shape[-1], scale_shape[-1],
int(data_shape[-1] / scale_shape[-1]), int(data_shape[-1] / scale_shape[-1]),
) )
scale = jnp.expand_dims(scale, axis=(-1, -3))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
scale = jnp.expand_dims(scale, axis=(flatten_axis + 2 - 2, -1))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers. # E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape( return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape(
data_shape data_shape
) )
funcs = { funcs = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling, ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling,
ScalingMode.NVTE_MXFP8_1D_SCALING: _dq_func_block_scaling, ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling,
} }
@staticmethod @staticmethod
......
...@@ -27,7 +27,14 @@ from transformer_engine.jax.sharding import global_shard_guard, MeshResource ...@@ -27,7 +27,14 @@ from transformer_engine.jax.sharding import global_shard_guard, MeshResource
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .. import cpp_extensions as tex from .. import cpp_extensions as tex
__all__ = ["QuantizeConfig", "fp8_autocast", "is_fp8_available", "update_collections"] __all__ = [
"QuantizeConfig",
"fp8_autocast",
"is_fp8_available",
"update_collections",
"get_delayed_scaling",
"NVTE_FP8_COLLECTION_NAME",
]
_is_fp8_available = None _is_fp8_available = None
_reason_for_no_fp8 = "" _reason_for_no_fp8 = ""
...@@ -87,15 +94,15 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: ...@@ -87,15 +94,15 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
A tuple of (bool, str) indicating support and any error message A tuple of (bool, str) indicating support and any error message
""" """
gpu_arch = get_device_compute_capability(gpu_id) gpu_arch = get_device_compute_capability(gpu_id)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
return _check_delayed_scaling_fp8_support(gpu_arch) return _check_delayed_scaling_fp8_support(gpu_arch)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _check_block_scaling_fp8_support(gpu_arch) return _check_block_scaling_fp8_support(gpu_arch)
return (False, "Unsupported scaling_mode!") return (False, "Unsupported scaling_mode!")
def is_fp8_available( def is_fp8_available(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
gpu_id=None, gpu_id=None,
) -> Tuple[bool, str]: ) -> Tuple[bool, str]:
"""Check if FP8 is available for the given scaling mode and GPU. """Check if FP8 is available for the given scaling mode and GPU.
...@@ -172,37 +179,12 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: ...@@ -172,37 +179,12 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
ValueError: If the recipe type is not supported ValueError: If the recipe type is not supported
""" """
if isinstance(fp8_recipe, recipe.DelayedScaling): if isinstance(fp8_recipe, recipe.DelayedScaling):
return ScalingMode.NVTE_DELAYED_TENSOR_SCALING return ScalingMode.DELAYED_TENSOR_SCALING
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return ScalingMode.NVTE_MXFP8_1D_SCALING return ScalingMode.MXFP8_1D_SCALING
raise ValueError("Invalid fp8_recipe!") raise ValueError("Invalid fp8_recipe!")
def update_collections(new: Collection, original: Collection) -> Collection:
"""Update collections with new values while preserving original structure.
Args:
new: New collection of values to add/update
original: Original collection to update
Returns:
Updated collection with new values merged with original
Raises:
AssertionError: If either collection is not a dict or FrozenDict
"""
assert isinstance(original, (dict, FrozenDict))
assert isinstance(new, (dict, FrozenDict))
frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
for key in new:
if key in frozen_original:
frozen_original, _ = frozen_original.pop(key)
new_coll = FrozenDict({**new, **frozen_original})
if not isinstance(original, FrozenDict):
new_coll = new_coll.unfreeze()
return new_coll
class QuantizeConfig: class QuantizeConfig:
"""Configuration class for quantization settings. """Configuration class for quantization settings.
...@@ -227,7 +209,7 @@ class QuantizeConfig: ...@@ -227,7 +209,7 @@ class QuantizeConfig:
INITIALIZED = False INITIALIZED = False
MARGIN: float = 0.0 MARGIN: float = 0.0
COLLECTION_NAME: str = "quantize_meta" COLLECTION_NAME: str = "fp8_metas"
FP8_FORMAT: recipe.Format = recipe.Format.HYBRID FP8_FORMAT: recipe.Format = recipe.Format.HYBRID
FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0] FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0]
BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1] BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1]
...@@ -235,7 +217,7 @@ class QuantizeConfig: ...@@ -235,7 +217,7 @@ class QuantizeConfig:
FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False
IF_QUANTIZE_2X: bool = False IF_QUANTIZE_2X: bool = False
SCALING_MODE: ScalingMode = ScalingMode.NVTE_NO_SCALING SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING
# DelayedScaling # DelayedScaling
AMAX_HISTORY_LEN: int = 1024 AMAX_HISTORY_LEN: int = 1024
...@@ -271,11 +253,11 @@ class QuantizeConfig: ...@@ -271,11 +253,11 @@ class QuantizeConfig:
cls.MARGIN = 0.0 cls.MARGIN = 0.0
cls.FP8_FORMAT = recipe.Format.HYBRID cls.FP8_FORMAT = recipe.Format.HYBRID
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.FP8_2X_ACC_FPROP = False cls.FP8_2X_ACC_FPROP = False
cls.FP8_2X_ACC_DGRAD = False cls.FP8_2X_ACC_DGRAD = False
cls.FP8_2X_ACC_WGRAD = False cls.FP8_2X_ACC_WGRAD = False
cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.IF_QUANTIZE_2X = False cls.IF_QUANTIZE_2X = False
# DelayedScaling # DelayedScaling
cls.AMAX_HISTORY_LEN = 1024 cls.AMAX_HISTORY_LEN = 1024
...@@ -414,3 +396,56 @@ def fp8_autocast( ...@@ -414,3 +396,56 @@ def fp8_autocast(
yield yield
finally: finally:
Config.finalize() Config.finalize()
def get_delayed_scaling():
r"""
Obtain an instance of DelayedScaling which is set via fp8_autocast.
.. note::
We only store :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`
, and :attr:`amax_compute_algo` via fp8_autocast. Other parameters in
recipe.DelayedScaling would be returned as the default values.
Returns
-------
delay_scaling : DelayedScaling
an instance of DelayedScaling which is set via fp8_autocast.
"""
amax_compute_algo = (
"max" if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent"
)
return recipe.DelayedScaling(
margin=int(QuantizeConfig.MARGIN),
fp8_format=QuantizeConfig.FP8_FORMAT,
amax_history_len=QuantizeConfig.AMAX_HISTORY_LEN,
amax_compute_algo=amax_compute_algo,
)
def update_collections(new: Collection, original: Collection) -> Collection:
r"""Update collections with new values while preserving original structure.
Args:
new: New collection of values to add/update
original: Original collection to update
Returns:
Updated collection with new values merged with original
Raises:
AssertionError: If either collection is not a dict or FrozenDict
"""
assert isinstance(original, (dict, FrozenDict))
assert isinstance(new, (dict, FrozenDict))
frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
for key in new:
if key in frozen_original:
frozen_original, _ = frozen_original.pop(key)
new_coll = FrozenDict({**new, **frozen_original})
if not isinstance(original, FrozenDict):
new_coll = new_coll.unfreeze()
return new_coll
NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME
...@@ -14,7 +14,7 @@ from typing import Union, Optional ...@@ -14,7 +14,7 @@ from typing import Union, Optional
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeAxis from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
...@@ -24,7 +24,7 @@ from .helper import ( ...@@ -24,7 +24,7 @@ from .helper import (
) )
__all__ = [ __all__ = [
"QuantizeAxis", "QuantizeLayout",
"Quantizer", "Quantizer",
"QuantizerSet", "QuantizerSet",
"DelayedScaleQuantizer", "DelayedScaleQuantizer",
...@@ -45,12 +45,12 @@ class Quantizer(ABC): ...@@ -45,12 +45,12 @@ class Quantizer(ABC):
Attributes: Attributes:
q_dtype: The data type for quantized values q_dtype: The data type for quantized values
scaling_mode: The scaling mode to use for quantization scaling_mode: The scaling mode to use for quantization
q_axis: The quantization axis (row-wise, column-wise, or both) q_layout: The quantization axis (row-wise, column-wise, or both)
""" """
q_dtype: jnp.dtype q_dtype: jnp.dtype
scaling_mode: ScalingMode scaling_mode: ScalingMode
q_axis: QuantizeAxis q_layout: QuantizeLayout
def tree_flatten(self): def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations. """Flatten the quantizer for JAX tree operations.
...@@ -59,7 +59,7 @@ class Quantizer(ABC): ...@@ -59,7 +59,7 @@ class Quantizer(ABC):
Tuple of (children, aux_data) for tree operations Tuple of (children, aux_data) for tree operations
""" """
children = () children = ()
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis) aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data) return (children, aux_data)
@classmethod @classmethod
...@@ -85,30 +85,31 @@ class Quantizer(ABC): ...@@ -85,30 +85,31 @@ class Quantizer(ABC):
Returns: Returns:
True if using both row-wise and column-wise quantization True if using both row-wise and column-wise quantization
""" """
return self.q_axis == QuantizeAxis.ROWWISE_COLWISE return self.q_layout == QuantizeLayout.ROWWISE_COLWISE
@abstractmethod @abstractmethod
def get_layout(self) -> str: def get_data_layout(self) -> str:
"""Get the data layout. """Get the data data_layout.
Returns: Returns:
Data layout in string format Data data_layout in string format
""" """
@abstractmethod @abstractmethod
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Core quantization function to be implemented by subclasses. """Core quantization function to be implemented by subclasses.
Args: Args:
x: Input tensor to quantize x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values, default is x.dtype dq_dtype: Data type for dequantized values, default is x.dtype
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x containing the quantized data A ScaledTensor1x containing the quantized data
""" """
def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None): def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1):
"""Quantize a tensor using the internal _quantize_func(). """Quantize a tensor using the internal _quantize_func().
Args: Args:
...@@ -116,21 +117,26 @@ class Quantizer(ABC): ...@@ -116,21 +117,26 @@ class Quantizer(ABC):
is_rowwise: Whether to use row-wise quantization is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data A ScaledTensor1x or ScaledTensor2x containing the quantized data
""" """
if (is_rowwise and is_colwise) or self.is_2x2x(): if (is_rowwise and is_colwise) or self.is_2x2x():
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype) rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype) colwise_tensor = self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor) return ScaledTensor2x(rowwise_tensor, colwise_tensor)
if is_colwise: if is_colwise:
return self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype) return self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
)
return self._quantize_func(x, dq_dtype=dq_dtype) return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
def get_scale_shapes(self, data_shape, is_padded=True): def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1):
"""Get shapes for scale tensors. """Get shapes for scale tensors.
Args: Args:
...@@ -140,7 +146,7 @@ class Quantizer(ABC): ...@@ -140,7 +146,7 @@ class Quantizer(ABC):
Returns: Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape) Tuple of (rowwise_scale_shape, colwise_scale_shape)
""" """
return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded) return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis)
def get_scale_dtype(self): def get_scale_dtype(self):
"""Get the data type for scale tensors. """Get the data type for scale tensors.
...@@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer): ...@@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer):
Attributes: Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE) q_layout: Quantization axis (default: ROWWISE_COLWISE)
scale: Current scaling factor scale: Current scaling factor
amax_history: History of maximum absolute values amax_history: History of maximum absolute values
""" """
scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field( amax_history: jnp.ndarray = field(
...@@ -181,35 +187,37 @@ class DelayedScaleQuantizer(Quantizer): ...@@ -181,35 +187,37 @@ class DelayedScaleQuantizer(Quantizer):
Tuple of (children, aux_data) for tree operations Tuple of (children, aux_data) for tree operations
""" """
children = (self.scale, self.amax_history) children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis) aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data) return (children, aux_data)
def get_layout(self) -> str: def get_data_layout(self) -> str:
"""Get the data layout string. """Get the data data_layout string.
Returns: Returns:
Data layout in string format Data data_layout in string format
Raises: Raises:
ValueError: If quantization axis is invalid ValueError: If quantization axis is invalid
""" """
layout = "NT" data_layout = "NT"
if self.q_axis == QuantizeAxis.ROWWISE_COLWISE: if self.q_layout == QuantizeLayout.ROWWISE_COLWISE:
return layout return data_layout
if self.q_axis == QuantizeAxis.ROWWISE: if self.q_layout == QuantizeLayout.ROWWISE:
return layout[0] return data_layout[0]
if self.q_axis == QuantizeAxis.COLWISE: if self.q_layout == QuantizeLayout.COLWISE:
return layout[1] return data_layout[1]
raise ValueError(f"Invalid q_axis: {self.q_axis}") raise ValueError(f"Invalid q_layout: {self.q_layout}")
def _quantize_func(self, x: jnp.ndarray, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: def _quantize_func(
self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
) -> ScaledTensor1x:
"""Quantize function helper for delayed scaling FP8. """Quantize function helper for delayed scaling FP8.
Args: Args:
x: Input tensor to quantize x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x containing the quantized data A ScaledTensor1x containing the quantized data
""" """
...@@ -232,9 +240,12 @@ class DelayedScaleQuantizer(Quantizer): ...@@ -232,9 +240,12 @@ class DelayedScaleQuantizer(Quantizer):
scale_inv=scale_inv, scale_inv=scale_inv,
scaling_mode=self.scaling_mode, scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
) )
def quantize(self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None): def quantize(
self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None, flatten_axis=-1
):
"""Quantize a tensor using the internal _quantize_func(). """Quantize a tensor using the internal _quantize_func().
Args: Args:
...@@ -242,32 +253,40 @@ class DelayedScaleQuantizer(Quantizer): ...@@ -242,32 +253,40 @@ class DelayedScaleQuantizer(Quantizer):
is_rowwise: Whether to use row-wise quantization is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data A ScaledTensor1x or ScaledTensor2x containing the quantized data
""" """
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
if flatten_axis < 0:
flatten_axis += x.ndim
assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"
is_rowwise = ( is_rowwise = (
is_rowwise is_rowwise
if is_rowwise is not None if is_rowwise is not None
else (self.q_axis == QuantizeAxis.ROWWISE or self.is_2x2x()) else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
) )
is_colwise = ( is_colwise = (
is_colwise is_colwise
if is_colwise is not None if is_colwise is not None
else (self.q_axis == QuantizeAxis.COLWISE or self.is_2x2x()) else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
) )
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype) rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = None colwise_tensor = None
if is_colwise: if is_colwise:
colwise_tensor = ScaledTensorFactory.create_1x( colwise_tensor = ScaledTensorFactory.create_1x(
data=jnp.transpose(rowwise_tensor.data, (-1, *range(rowwise_tensor.data.ndim - 1))), data=jnp.transpose(
rowwise_tensor.data, (*range(flatten_axis, x.ndim), *range(flatten_axis))
),
scale_inv=rowwise_tensor.scale_inv, scale_inv=rowwise_tensor.scale_inv,
scaling_mode=self.scaling_mode, scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
is_colwise=True, is_colwise=True,
layout="T", data_layout="T",
flatten_axis=flatten_axis,
) )
if is_colwise and is_rowwise: if is_colwise and is_rowwise:
return ScaledTensor2x(rowwise_tensor, colwise_tensor) return ScaledTensor2x(rowwise_tensor, colwise_tensor)
...@@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer): ...@@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer):
Attributes: Attributes:
scaling_mode: Set to NVTE_MXFP8_1D_SCALING scaling_mode: Set to NVTE_MXFP8_1D_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE) q_layout: Quantization axis (default: ROWWISE_COLWISE)
""" """
scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
def get_layout(self) -> str: def get_data_layout(self) -> str:
"""Get the data layout string. """Get the data data_layout string.
Returns: Returns:
Data layout in string format Data data_layout in string format
""" """
if self.is_2x2x(): if self.is_2x2x():
return "NN" return "NN"
return "N" return "N"
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Quantize function helper for block scaling FP8. """Quantize function helper for block scaling FP8.
Args: Args:
x: Input tensor to quantize x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x containing the quantized data A ScaledTensor1x containing the quantized data
""" """
# TODO(Phuong): use quantize_func from JAX # TODO(Phuong): use quantize_func from JAX
if flatten_axis < 0:
flatten_axis = x.ndim + flatten_axis
assert (
0 <= flatten_axis < x.ndim
), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}"
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
x_shape = x.shape x_shape = x.shape
scale_shape = self.scaling_mode.get_scale_shape(x_shape, is_colwise, is_padded=False) scale_shape = self.scaling_mode.get_scale_shape(
x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
)
scale_dtype = self.scaling_mode.get_scale_dtype() scale_dtype = self.scaling_mode.get_scale_dtype()
x = x.reshape( x = x.reshape(
*x_shape[:-2], *x_shape[: flatten_axis - 1],
scale_shape[-2], scale_shape[flatten_axis - 1],
int(x_shape[-2] / scale_shape[-2]), int(x_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*x_shape[flatten_axis:-1],
scale_shape[-1], scale_shape[-1],
int(x_shape[-1] / scale_shape[-1]), int(x_shape[-1] / scale_shape[-1]),
) )
amax = jnp.max(jnp.abs(x), axis=(-3, -1), keepdims=True) amax = jnp.max(jnp.abs(x), axis=(flatten_axis + 2 - 2, -1), keepdims=True)
MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32) MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
scales = amax.astype(jnp.float32) / MAX scales = amax.astype(jnp.float32) / MAX
...@@ -409,6 +438,7 @@ class BlockScaleQuantizer(Quantizer): ...@@ -409,6 +438,7 @@ class BlockScaleQuantizer(Quantizer):
self.scaling_mode, self.scaling_mode,
is_colwise=is_colwise, is_colwise=is_colwise,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
) )
def _cast_to_e8m0_with_rounding_up(self, scales): def _cast_to_e8m0_with_rounding_up(self, scales):
...@@ -500,8 +530,8 @@ class QuantizerFactory: ...@@ -500,8 +530,8 @@ class QuantizerFactory:
""" """
quantizer_type_map = { quantizer_type_map = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScaleQuantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer,
} }
@staticmethod @staticmethod
...@@ -509,7 +539,7 @@ class QuantizerFactory: ...@@ -509,7 +539,7 @@ class QuantizerFactory:
n_quantizers: int = 1, n_quantizers: int = 1,
scaling_mode: ScalingMode = None, scaling_mode: ScalingMode = None,
q_dtype: jnp.dtype = None, q_dtype: jnp.dtype = None,
q_axis: QuantizeAxis = None, q_layout: QuantizeLayout = None,
**kwargs, **kwargs,
) -> Quantizer: ) -> Quantizer:
"""Create one or more quantizers with specified parameters. """Create one or more quantizers with specified parameters.
...@@ -518,15 +548,17 @@ class QuantizerFactory: ...@@ -518,15 +548,17 @@ class QuantizerFactory:
n_quantizers: Number of quantizers to create n_quantizers: Number of quantizers to create
scaling_mode: Scaling mode to use scaling_mode: Scaling mode to use
q_dtype: Quantization data type q_dtype: Quantization data type
q_axis: Quantization axis q_layout: Quantization axis
flatten_axis: The quantization axis for the tensor
**kwargs: Additional arguments for quantizer initialization **kwargs: Additional arguments for quantizer initialization
Returns: Returns:
A single quantizer or tuple of quantizers A single quantizer or tuple of quantizers
""" """
# (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted # (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted
# assert scaling_mode != ScalingMode.NVTE_INVALID_SCALING assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type"
if scaling_mode in (ScalingMode.NVTE_NO_SCALING, ScalingMode.NVTE_INVALID_SCALING): # import pdb; pdb.set_trace()
if scaling_mode == ScalingMode.NO_SCALING:
quantizers = [None] * n_quantizers quantizers = [None] * n_quantizers
else: else:
quantizers = [] quantizers = []
...@@ -534,7 +566,7 @@ class QuantizerFactory: ...@@ -534,7 +566,7 @@ class QuantizerFactory:
quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode) quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)
quantizers.append( quantizers.append(
quantizer_type( quantizer_type(
q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis, **kwargs q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs
) )
) )
return quantizers[0] if len(quantizers) == 1 else tuple(quantizers) return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)
...@@ -554,11 +586,11 @@ class QuantizerFactory: ...@@ -554,11 +586,11 @@ class QuantizerFactory:
A QuantizerSet instance A QuantizerSet instance
""" """
if is_2x2x: if is_2x2x:
q_axis_x = q_axis_kernel = q_axis_dgrad = QuantizeAxis.ROWWISE_COLWISE q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
else: else:
q_axis_x = QuantizeAxis.ROWWISE q_layout_x = QuantizeLayout.ROWWISE
q_axis_kernel = QuantizeAxis.COLWISE q_layout_kernel = QuantizeLayout.COLWISE
q_axis_dgrad = None q_layout_dgrad = None
if "quantize_meta_set" in kwargs: if "quantize_meta_set" in kwargs:
quantize_meta_set = kwargs.get("quantize_meta_set") quantize_meta_set = kwargs.get("quantize_meta_set")
...@@ -577,9 +609,11 @@ class QuantizerFactory: ...@@ -577,9 +609,11 @@ class QuantizerFactory:
else: else:
args_x = args_kernel = args_grad = {} args_x = args_kernel = args_grad = {}
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_x, **args_x) q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, **args_x)
q_kernel = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_kernel, **args_kernel) q_kernel = QuantizerFactory.create(
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_axis_dgrad, **args_grad) 1, scaling_mode, fwd_dtype, q_layout_kernel, **args_kernel
)
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_layout_dgrad, **args_grad)
return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad) return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)
@staticmethod @staticmethod
...@@ -618,4 +652,4 @@ class QuantizerFactory: ...@@ -618,4 +652,4 @@ class QuantizerFactory:
return q_set[0] if len(q_set) == 1 else tuple(q_set) return q_set[0] if len(q_set) == 1 else tuple(q_set)
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NVTE_NO_SCALING) noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING)
...@@ -16,11 +16,33 @@ from typing import Tuple, Dict ...@@ -16,11 +16,33 @@ from typing import Tuple, Dict
from functools import reduce from functools import reduce
import operator import operator
from jax.experimental.custom_partitioning import CompoundFactor
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp import jax.numpy as jnp
from transformer_engine_jax import JAXX_Scaling_Mode
__all__ = ["ScalingMode"]
__all__ = ["QuantizeShardyRules", "ScalingMode"]
@dataclass
class QuantizeShardyRules:
"""Information necessary to shard scale tensors with Shardy.
Attributes:
input_spec: Specification for the input axes
rowwise_rule: Sharding rule for the row-wise scale tensor, depends on
the axes in `input_spec`
colwise_rule: Likewise for the column-wise scale tensor.
factor_sizes: For block scaling, contains the block size factor, which is
used in `input_spec`.
"""
input_spec: Tuple[str]
rowwise_rule: Tuple[str]
colwise_rule: Tuple[str]
factor_sizes: Dict[str, int]
class ScalingModeMetadataImpl(ABC): class ScalingModeMetadataImpl(ABC):
...@@ -40,7 +62,11 @@ class ScalingModeMetadataImpl(ABC): ...@@ -40,7 +62,11 @@ class ScalingModeMetadataImpl(ABC):
@abstractmethod @abstractmethod
def get_scale_shape( def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True self,
data_shape: Tuple[int, ...],
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
"""Get the shape for scale tensors. """Get the shape for scale tensors.
...@@ -48,11 +74,26 @@ class ScalingModeMetadataImpl(ABC): ...@@ -48,11 +74,26 @@ class ScalingModeMetadataImpl(ABC):
data_shape: The shape of the tensor being quantized data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns: Returns:
The shape for scale tensors The shape for scale tensors
""" """
@abstractmethod
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl): class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for delayed scaling mode. """Implementation for delayed scaling mode.
...@@ -69,7 +110,11 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -69,7 +110,11 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
return jnp.float32 return jnp.float32
def get_scale_shape( def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True self,
data_shape: Tuple[int, ...],
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
"""Get the shape for scale tensors in delayed scaling. """Get the shape for scale tensors in delayed scaling.
...@@ -77,6 +122,7 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -77,6 +122,7 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
data_shape: The shape of the tensor being scaled data_shape: The shape of the tensor being scaled
is_colwise: Whether the scaling is column-wise is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns: Returns:
The shape for scale tensors - (1,) The shape for scale tensors - (1,)
...@@ -84,6 +130,23 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -84,6 +130,23 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
del data_shape, is_colwise del data_shape, is_colwise
return (1,) return (1,)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
del flatten_axis
input_spec = tuple(f"x{i}" for i in range(input_rank))
return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {})
class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for block scaling mode. """Implementation for block scaling mode.
...@@ -113,8 +176,35 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -113,8 +176,35 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
""" """
return jnp.float8_e8m0fnu return jnp.float8_e8m0fnu
def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim):
"""Remove excess padding from the scale shape and return the shape with respect to the original data shape."""
if len(data_shape) > 1:
# handle last dim
assert data_shape[-1] % scale_block_dim == 0
last = data_shape[-1] // scale_block_dim
scale_shape = (last,)
assert n_scale_blocks % last == 0
n_scale_blocks //= last
# handle middle dim, exclude first and last
for mid in reversed(data_shape[1:-1]):
scale_shape = (mid,) + scale_shape
assert n_scale_blocks % mid == 0
n_scale_blocks //= mid
scale_shape = (n_scale_blocks,) + scale_shape
else:
scale_shape = (n_scale_blocks,)
assert len(scale_shape) == len(
data_shape
), f"scale_shape {scale_shape}, data_shape {data_shape}"
return scale_shape
def get_scale_shape( def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True self,
data_shape: Tuple[int, ...],
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
"""Get the shape for scale tensors in block scaling. """Get the shape for scale tensors in block scaling.
...@@ -122,6 +212,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -122,6 +212,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
data_shape: The shape of the tensor being quantized data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns: Returns:
The shape for scale tensors The shape for scale tensors
...@@ -135,38 +226,87 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): ...@@ -135,38 +226,87 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
block_x, block_y = self._block_dims block_x, block_y = self._block_dims
alignment_x, alignment_y = block_alignment alignment_x, alignment_y = block_alignment
seq_axis = len(data_shape) - 2 if flatten_axis < 0:
flatten_axis = len(data_shape) + flatten_axis
assert ( assert (
data_shape[seq_axis] % block_x == 0 0 < flatten_axis < len(data_shape)
), f"Input data of shape {data_shape} should be padded by {block_x} in axes={seq_axis}" ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
assert data_shape[flatten_axis - 1] % block_x == 0, (
f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
f" {flatten_axis - 1}"
)
assert ( assert (
data_shape[-1] % block_y == 0 data_shape[-1] % block_y == 0
), f"Input data of shape {data_shape} should be padded by {block_y} in axis -1" ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"
# NOTE: this overpads if dim > 2 and dims before seq_axis are greater than 1 flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
n_block_seq = data_shape[seq_axis] // block_x flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)
n_block_y = data_shape[-1] // block_y
n_flat_first_dim = reduce(operator.mul, data_shape[:seq_axis], 1) * n_block_seq assert flattened_first_dim % block_x == 0, (
f"Flattened first dim - mutiplication of axes={tuple(range(0, flatten_axis))} of shape"
f" {data_shape} - should be divisible by block_x {block_x}"
)
assert flattened_last_dim % block_y == 0, (
"Flattened last dim - mutiplication of"
f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be"
f" divisible by block_y {block_y}"
)
# Padding n_block_x = int(flattened_first_dim / block_x)
n_flat_first_dim = ((n_flat_first_dim + alignment_x - 1) // alignment_x) * alignment_x n_block_y = int(flattened_last_dim / block_y)
n_block_y = ((n_block_y + alignment_y - 1) // alignment_y) * alignment_y
out_shape = () # padding
for i in range(seq_axis): n_block_x = int(((n_block_x + alignment_x - 1) // alignment_x) * alignment_x)
d = data_shape[i] n_block_y = int(((n_block_y + alignment_y - 1) // alignment_y) * alignment_y)
out_shape += (d,)
assert n_flat_first_dim % d == 0
n_flat_first_dim //= d
out_shape += (n_flat_first_dim, n_block_y) first_dim_scale_shape = self._apply_scale_shape_correction(
data_shape[:flatten_axis], n_block_x, block_x
)
last_dim_scale_shape = self._apply_scale_shape_correction(
data_shape[flatten_axis:], n_block_y, block_y
)
return out_shape return (*first_dim_scale_shape, *last_dim_scale_shape)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
# (Phuong: Map the NVTEScalingMode value to the ScalingMode Returns:
The Shardy rules for the scaling mode
"""
input_spec = [f"x{i}" for i in range(input_rank)]
# We have to use two different factors in the two CompoundFactors because of Shardy
# verifier requirements, even though they are the same.
rowwise_var = unique_var
colwise_var = f"{unique_var}_"
input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")
# The rowwise and colwise scale tensors should be sharded the same way as the input.
# However, we need to adjust the dimensions where the block scaling factor applies.
rowwise = input_spec.copy()
rowwise[-1] = rowwise_var
colwise = input_spec.copy()
colwise[flatten_axis - 1] = colwise_var
# This implementation needs to be updated for different block dims.
assert self._block_dims == (1, 32)
return QuantizeShardyRules(
tuple(input_spec),
tuple(rowwise),
tuple(colwise),
{"block_size_rowwise": 32, "block_size_colwise": 32},
)
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -175,16 +315,14 @@ class ScalingMode(Enum): ...@@ -175,16 +315,14 @@ class ScalingMode(Enum):
"""Enumeration of tensor scaling modes with their corresponding metadata implementations. """Enumeration of tensor scaling modes with their corresponding metadata implementations.
This class defines the available scaling modes for tensor quantization: This class defines the available scaling modes for tensor quantization:
- NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- NVTE_INVALID_SCALING: Invalid scaling mode - NO_SCALING: No scaling applied
- NVTE_NO_SCALING: No scaling applied
""" """
NVTE_DELAYED_TENSOR_SCALING = 0 NO_SCALING = JAXX_Scaling_Mode.NO_SCALING
NVTE_MXFP8_1D_SCALING = 1 DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING
NVTE_INVALID_SCALING = 2 MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING
NVTE_NO_SCALING = 3
def _get_impl(self) -> ScalingModeMetadataImpl: def _get_impl(self) -> ScalingModeMetadataImpl:
"""Get the implementation for this scaling mode. """Get the implementation for this scaling mode.
...@@ -208,34 +346,54 @@ class ScalingMode(Enum): ...@@ -208,34 +346,54 @@ class ScalingMode(Enum):
""" """
return self._get_impl().get_scale_dtype() return self._get_impl().get_scale_dtype()
def get_scale_shape_2x(self, data_shape, is_padded=True) -> Tuple[Tuple[int]]: def get_scale_shape_2x(self, data_shape, is_padded=True, flatten_axis=-1) -> Tuple[Tuple[int]]:
"""Get shapes for both row-wise and column-wise scaling. """Get shapes for both row-wise and column-wise scaling.
Args: Args:
data_shape: Shape of the data tensor data_shape: Shape of the data tensor
is_padded: Whether to use padded shapes is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns: Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape) Tuple of (rowwise_scale_shape, colwise_scale_shape)
""" """
rowwise_scale_shape = self.get_scale_shape( rowwise_scale_shape = self.get_scale_shape(
data_shape, is_colwise=False, is_padded=is_padded data_shape, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis
)
colwise_scale_shape = self.get_scale_shape(
data_shape, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis
) )
colwise_scale_shape = self.get_scale_shape(data_shape, is_colwise=True, is_padded=is_padded)
return (rowwise_scale_shape, colwise_scale_shape) return (rowwise_scale_shape, colwise_scale_shape)
def get_scale_shape(self, data_shape, is_colwise, is_padded=True) -> Tuple[int]: def get_scale_shape(
self, data_shape, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
"""Get the shape for scale tensors in this mode. """Get the shape for scale tensors in this mode.
Args: Args:
data_shape: Shape of the data tensor data_shape: Shape of the data tensor
is_colwise: Whether to use column-wise scaling is_colwise: Whether to use column-wise scaling
is_padded: Whether to use padded shapes is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns: Returns:
The shape for scale tensors The shape for scale tensors
""" """
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded) return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis=-1
) -> Tuple[Tuple[str]]:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
Returns:
The Shardy rules for the scaling mode
"""
return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)
def __eq__(self, other): def __eq__(self, other):
"""Compare this scaling mode with another. """Compare this scaling mode with another.
...@@ -273,8 +431,8 @@ class ScalingMode(Enum): ...@@ -273,8 +431,8 @@ class ScalingMode(Enum):
SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = { SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
# WAR # WAR
ScalingMode.NVTE_NO_SCALING: DelayedScalingModeMetadataImpl(), ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(),
} }
...@@ -15,7 +15,7 @@ from abc import ABC, abstractmethod ...@@ -15,7 +15,7 @@ from abc import ABC, abstractmethod
import jax.numpy as jnp import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeAxis from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .dequantizer import Dequantizer from .dequantizer import Dequantizer
...@@ -84,6 +84,17 @@ class ScaledTensor(ABC): ...@@ -84,6 +84,17 @@ class ScaledTensor(ABC):
ValueError: If called on a tensor that doesn't support column-wise access ValueError: If called on a tensor that doesn't support column-wise access
""" """
@abstractmethod
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
@register_pytree_node_class @register_pytree_node_class
@dataclass @dataclass
...@@ -100,7 +111,8 @@ class ScaledTensor1x(ScaledTensor): ...@@ -100,7 +111,8 @@ class ScaledTensor1x(ScaledTensor):
dq_dtype: The data type for dequantized values dq_dtype: The data type for dequantized values
_dq_func: The dequantization function _dq_func: The dequantization function
is_colwise: Whether the tensor uses column-wise quantization is_colwise: Whether the tensor uses column-wise quantization
layout: The layout specification for the tensor data_layout: The data_layout specification for the tensor
flatten_axis: The quantization axis for the tensor
""" """
data: jnp.ndarray data: jnp.ndarray
...@@ -109,7 +121,8 @@ class ScaledTensor1x(ScaledTensor): ...@@ -109,7 +121,8 @@ class ScaledTensor1x(ScaledTensor):
dq_dtype: jnp.dtype dq_dtype: jnp.dtype
_dq_func: Callable _dq_func: Callable
is_colwise: bool is_colwise: bool
layout: str data_layout: str
flatten_axis: int = -1
def __post_init__(self): def __post_init__(self):
"""Validates and adjusts the scale_inv shape after initialization. """Validates and adjusts the scale_inv shape after initialization.
...@@ -117,11 +130,22 @@ class ScaledTensor1x(ScaledTensor): ...@@ -117,11 +130,22 @@ class ScaledTensor1x(ScaledTensor):
Ensures the scale_inv shape matches the expected shape based on the scaling mode Ensures the scale_inv shape matches the expected shape based on the scaling mode
and quantization direction. Pads the scale_inv if necessary. and quantization direction. Pads the scale_inv if necessary.
""" """
flatten_axis = (
len(self.data.shape) + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis
)
assert (
0 < flatten_axis < len(self.data.shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {self.data.shape}"
if self.data_layout == "T":
flatten_axis = self.data.ndim - flatten_axis
self.flatten_axis = flatten_axis
expected_scale_shape = self.scaling_mode.get_scale_shape( expected_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=True self.data.shape, self.is_colwise, is_padded=True, flatten_axis=flatten_axis
) )
expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=False self.data.shape, self.is_colwise, is_padded=False, flatten_axis=flatten_axis
) )
if self.scale_inv.shape != expected_scale_shape: if self.scale_inv.shape != expected_scale_shape:
assert self.scale_inv.shape == expected_unpadded_scale_shape, ( assert self.scale_inv.shape == expected_unpadded_scale_shape, (
...@@ -144,7 +168,14 @@ class ScaledTensor1x(ScaledTensor): ...@@ -144,7 +168,14 @@ class ScaledTensor1x(ScaledTensor):
A tuple containing (children, aux_data) for tree operations A tuple containing (children, aux_data) for tree operations
""" """
children = (self.data, self.scale_inv) children = (self.data, self.scale_inv)
aux_data = (self.scaling_mode, self.dq_dtype, self._dq_func, self.is_colwise, self.layout) aux_data = (
self.scaling_mode,
self.dq_dtype,
self._dq_func,
self.is_colwise,
self.data_layout,
self.flatten_axis,
)
return (children, aux_data) return (children, aux_data)
def dequantize(self): def dequantize(self):
...@@ -183,6 +214,45 @@ class ScaledTensor1x(ScaledTensor): ...@@ -183,6 +214,45 @@ class ScaledTensor1x(ScaledTensor):
raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!") raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!")
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
if not logical_axis_names:
return self
# axis_names were given for N layout, so needs to be transpose for T layout
if self.data_layout == "T":
assert self.flatten_axis > 0
flatten_axis = -self.flatten_axis
axis_names = (*logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis])
else:
axis_names = logical_axis_names
data = with_sharding_constraint_by_logical_axes(self.data, axis_names)
if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
# TODO(Phuong): Handle padding !?
scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
else:
scale_inv = self.scale_inv
return ScaledTensor1x(
data=data,
scale_inv=scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=self.dq_dtype,
_dq_func=self._dq_func,
is_colwise=self.is_colwise,
data_layout=self.data_layout,
flatten_axis=self.flatten_axis,
)
@register_pytree_node_class @register_pytree_node_class
@dataclass @dataclass
...@@ -233,6 +303,27 @@ class ScaledTensor2x(ScaledTensor): ...@@ -233,6 +303,27 @@ class ScaledTensor2x(ScaledTensor):
""" """
return self.colwise_tensor return self.colwise_tensor
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
if not logical_axis_names:
return self
rowwise_tensor = self.rowwise_tensor.apply_sharding_constraint_by_logical_axes(
logical_axis_names
)
colwise_tensor = self.colwise_tensor.apply_sharding_constraint_by_logical_axes(
logical_axis_names
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
@dataclass @dataclass
class ScaledTensorFactory: class ScaledTensorFactory:
...@@ -244,7 +335,13 @@ class ScaledTensorFactory: ...@@ -244,7 +335,13 @@ class ScaledTensorFactory:
@staticmethod @staticmethod
def create_1x( def create_1x(
data, scale_inv, scaling_mode, dq_dtype=jnp.bfloat16, is_colwise=False, layout="N" data,
scale_inv,
scaling_mode,
dq_dtype=jnp.bfloat16,
is_colwise=False,
data_layout="N",
flatten_axis=-1,
): ):
"""Creates a single-scale quantized tensor. """Creates a single-scale quantized tensor.
...@@ -254,13 +351,16 @@ class ScaledTensorFactory: ...@@ -254,13 +351,16 @@ class ScaledTensorFactory:
scaling_mode: The scaling mode for quantization scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16) dq_dtype: The data type for dequantized values (default: bfloat16)
is_colwise: Whether to use column-wise quantization (default: False) is_colwise: Whether to use column-wise quantization (default: False)
layout: The layout specification (default: "N") data_layout: The data_layout specification (default: "N")
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x instance A ScaledTensor1x instance
""" """
dq_func = Dequantizer.funcs.get(scaling_mode) dq_func = Dequantizer.funcs.get(scaling_mode)
return ScaledTensor1x(data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, layout) return ScaledTensor1x(
data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, data_layout, flatten_axis
)
@staticmethod @staticmethod
def create_2x( def create_2x(
...@@ -270,7 +370,8 @@ class ScaledTensorFactory: ...@@ -270,7 +370,8 @@ class ScaledTensorFactory:
colwise_scale_inv, colwise_scale_inv,
scaling_mode, scaling_mode,
dq_dtype=jnp.bfloat16, dq_dtype=jnp.bfloat16,
layout="NN", data_layout="NN",
flatten_axis=-1,
): ):
"""Creates a double-scale quantized tensor. """Creates a double-scale quantized tensor.
...@@ -281,7 +382,8 @@ class ScaledTensorFactory: ...@@ -281,7 +382,8 @@ class ScaledTensorFactory:
colwise_scale_inv: The column-wise inverse scaling factors colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16) dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN") data_layout: The data_layout specification (default: "NN")
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor2x instance A ScaledTensor2x instance
...@@ -294,7 +396,8 @@ class ScaledTensorFactory: ...@@ -294,7 +396,8 @@ class ScaledTensorFactory:
dq_dtype, dq_dtype,
dq_func, dq_func,
is_colwise=False, is_colwise=False,
layout=layout[0], data_layout=data_layout[0],
flatten_axis=flatten_axis,
) )
colwise_tensor = ScaledTensor1x( colwise_tensor = ScaledTensor1x(
colwise_data, colwise_data,
...@@ -303,7 +406,8 @@ class ScaledTensorFactory: ...@@ -303,7 +406,8 @@ class ScaledTensorFactory:
dq_dtype, dq_dtype,
dq_func, dq_func,
is_colwise=True, is_colwise=True,
layout=layout[1], data_layout=data_layout[1],
flatten_axis=flatten_axis,
) )
return ScaledTensor2x(rowwise_tensor, colwise_tensor) return ScaledTensor2x(rowwise_tensor, colwise_tensor)
...@@ -315,8 +419,9 @@ class ScaledTensorFactory: ...@@ -315,8 +419,9 @@ class ScaledTensorFactory:
colwise_scale_inv: jnp.ndarray, colwise_scale_inv: jnp.ndarray,
scaling_mode: ScalingMode, scaling_mode: ScalingMode,
dq_dtype: jnp.dtype = jnp.bfloat16, dq_dtype: jnp.dtype = jnp.bfloat16,
layout: str = "NN", data_layout: str = "NN",
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE, q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
flatten_axis: int = -1,
): ):
"""Creates a scaled tensor based on the quantization axis. """Creates a scaled tensor based on the quantization axis.
...@@ -327,13 +432,13 @@ class ScaledTensorFactory: ...@@ -327,13 +432,13 @@ class ScaledTensorFactory:
colwise_scale_inv: The column-wise inverse scaling factors colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16) dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN") data_layout: The data_layout specification (default: "NN")
q_axis: The quantization axis (default: ROWWISE) q_layout: The quantization axis (default: ROWWISE)
Returns: Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_axis Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
""" """
if q_axis == QuantizeAxis.ROWWISE_COLWISE: if q_layout == QuantizeLayout.ROWWISE_COLWISE:
return ScaledTensorFactory.create_2x( return ScaledTensorFactory.create_2x(
data, data,
scale_inv, scale_inv,
...@@ -341,12 +446,19 @@ class ScaledTensorFactory: ...@@ -341,12 +446,19 @@ class ScaledTensorFactory:
colwise_scale_inv, colwise_scale_inv,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
layout=layout, data_layout=data_layout,
flatten_axis=flatten_axis,
) )
is_colwise = q_axis == QuantizeAxis.COLWISE is_colwise = q_layout == QuantizeLayout.COLWISE
return ScaledTensorFactory.create_1x( return ScaledTensorFactory.create_1x(
data, scale_inv, scaling_mode, dq_dtype, is_colwise=is_colwise, layout=layout[0] data,
scale_inv,
scaling_mode,
dq_dtype,
is_colwise=is_colwise,
data_layout=data_layout[0],
flatten_axis=flatten_axis,
) )
...@@ -360,24 +472,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, . ...@@ -360,24 +472,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, .
Returns: Returns:
The tensor with applied sharding constraints The tensor with applied sharding constraints
""" """
if isinstance(x, ScaledTensor1x): if isinstance(x, ScaledTensor):
return ScaledTensor1x( return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)
data=with_sharding_constraint_by_logical_axes(x.data, logical_axis_names),
scale_inv=x.scale_inv,
scaling_mode=x.scaling_mode,
dq_dtype=x.dq_dtype,
_dq_func=x._dq_func,
is_colwise=x.is_colwise,
layout=x.layout,
)
if isinstance(x, ScaledTensor2x):
return ScaledTensor2x(
rowwise_tensor=with_sharding_constraint_by_logical_axes(
x.rowwise_tensor, logical_axis_names
),
colwise_tensor=with_sharding_constraint_by_logical_axes(
x.colwise_tensor, logical_axis_names
),
)
return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names) return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment