"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "e57c0c301c7fbe98d9c96a86890443cd83fc5bb9"
Commit ab3e5a92 authored by yuguo's avatar yuguo
Browse files


Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
@@ -31,6 +31,9 @@
#include "transformer_engine/activation.h"
#include "utils.h"
// ENUM_ATTR and DICT_ATTR decoding needs to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
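// With this registration in place, the FFI bindings below can declare the enum
// attribute directly instead of a raw int64_t, e.g.:
//   .Attr<JAXX_Scaling_Mode>("scaling_mode")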
namespace transformer_engine {
namespace jax {
@@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x);
// Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler);
@@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, DType out_dtype,
NVTE_Norm_Type norm_type, int scaling_mode,
NVTE_Norm_Type norm_type,
JAXX_Scaling_Mode scaling_mode,
bool zero_centered_gamma, float epsilon, int sm_margin,
bool is_training);
@@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
int scaling_mode, bool is_2x);
DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout);
// Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
......
@@ -11,21 +11,13 @@
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"
namespace {
bool is_gated(NVTE_Activation_Type act_type) {
return act_type == NVTE_Activation_Type::GEGLU || act_type == NVTE_Activation_Type::SWIGLU ||
act_type == NVTE_Activation_Type::REGLU || act_type == NVTE_Activation_Type::QGEGLU ||
act_type == NVTE_Activation_Type::SREGLU;
}
} // namespace
namespace transformer_engine {
namespace jax {
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, int64_t scaling_mode_enum,
Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
@@ -42,40 +34,59 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto n = input_dims.back();
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto act_len = input_dims[input_dims.size() - 2];
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
auto input_shape = std::vector<size_t>{m, act_len * n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(scaling_mode);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
}
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
}
if (is_2x) {
output_tensor.set_columnwise_data(colwise_output, static_cast<DType>(out_dtype), output_shape);
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? scale_inv_buf
: colwise_scale_inv_buf;
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
}
}
switch (act_type) {
@@ -128,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum")
.Attr<int64_t>("scaling_mode")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"),
FFI_CudaGraph_Traits);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
int scaling_mode, bool is_2x) {
JAXX_Scaling_Mode scaling_mode, bool is_2x) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
@@ -153,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
auto dact_input_tensor =
TensorWrapper(reinterpret_cast<void *>(&temp), dact_input_shape, in_dtype);
auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
auto output_tensor = TensorWrapper(static_cast<NVTEScalingMode>(scaling_mode));
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) {
@@ -162,8 +173,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
}
if (is_2x) {
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype,
output_trans_shape);
auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) {
@@ -172,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
}
}
if (is_fp8_dtype(out_dtype) && scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
@@ -190,22 +202,25 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf,
Result_Type amax_out_buf, Result_Type dbias_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x,
bool is_dbias, int64_t act_enum) {
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
auto *input = input_buf.untyped_data();
auto *act_input = act_input_buf.untyped_data();
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
auto *colwise_output = colwise_output_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data();
void *workspace = workspace_buf->untyped_data();
@@ -213,67 +228,76 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto act_input_dims = act_input_buf.dimensions();
auto workspace_dims = workspace_buf->dimensions();
// m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
// n = ir_dz_shape[-1], ir_dz_shape == input_dims
auto input_ranks = input_dims.size();
auto act_input_ranks = act_input_dims.size();
auto m = product(act_input_dims, 0, act_input_dims.size() - 1);
// 'n' will be 2x the size of input_dims.back() if the dactivation is gated (dgated)
auto n = act_input_dims.back();
auto input_shape = std::vector<size_t>{m, input_dims.back()};
auto act_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{m, n};
auto dbias_shape = std::vector<size_t>{n};
// n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
auto act_len = act_input_dims[act_input_dims.size() - 2];
NVTE_CHECK(act_input_dims.back() == input_dims.back(),
"Shape mismatch between activation input and gradient input");
auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
auto n = input_dims.back();
auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n * act_len};
auto output_shape = std::vector<size_t>{m, n * act_len};
auto output_trans_shape = std::vector<size_t>{n * act_len, m};
auto dbias_shape = std::vector<size_t>{n * act_len};
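// Example: for a gated activation (act_len == 2) with hidden width n, the incoming
// gradient dz (input) has shape {m, n}, the forward activation input {m, 2*n}, and
// the rowwise/columnwise quantized outputs {m, 2*n} / {2*n, m}.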
std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
auto output_tensor = TensorWrapper(scaling_mode);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax_out, 0, sizeof(float), stream);
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
}
if (is_2x) {
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf;
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? scale_inv_buf
: colwise_scale_inv_buf;
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
}
}
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK(!(is_gated(act_type) && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x &&
is_gated(act_type)),
NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
"TE/common does not support delayed scaling for 2x with gated activations.");
if (is_dbias) {
@@ -361,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("scaling_mode")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<bool>("is_dbias")
.Attr<int64_t>("act_enum"),
.Attr<bool>("is_dbias"),
FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
@@ -15,47 +15,34 @@
namespace transformer_engine {
namespace jax {
constexpr static size_t MXFP8_BLOCK_SIZE = 32;
// Note: we only support TN-GEMM for now (TN in cuBLASLt == NT in JAX)
Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lhs_sinv_ptr,
const DType &lhs_sinv_dtype, uint8_t *rhs_ptr, const DType &rhs_dtype,
uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr,
const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype,
uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms,
int32_t *dim_list_ptr, const int64_t &scaling_mode,
cudaStream_t stream) {
size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
"sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
size_t dim_list_bytes = sizeof(int32_t) * 3 * num_gemms;
std::unique_ptr<int32_t[]> dim_list_host = std::make_unique<int32_t[]>(3 * num_gemms);
cudaMemcpyAsync(dim_list_host.get(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream);
// Note: This may break cudaGraph.
cudaStreamSynchronize(stream);
Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
Variadic_Result_Type output_list, int64_t num_gemms,
JAXX_Scaling_Mode scaling_mode, int64_t has_bias) {
// Notes on matrix layouts and transpose:
// Jax uses row-major layout, on entering this function, each input matrix pair:
// JAX uses row-major data layout; on entering this function, each input matrix pair is:
// A: row-major with size [m, k],
// B: row-major with size [n, k], needs transpose,
// on exiting this function, JAX expect:
// C: row-major with size [m, n].
// cuBLAS uses column-major layout, in this view, each input matrix pair:
// cuBLAS uses column-major data layout; in this view, each input matrix pair is:
// A: column-major with size [k, m], needs transpose,
// B: column-major with size [k, n].
// If we call cuBLAS GEMM for A * B, the output will be:
// C: column-major with size [m, n] --> row-major with size [n, m].
// To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.
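// A worked sketch of the swap: a row-major X[r, c] reinterpreted as column-major
// is X^T[c, r], so the column-major views of A[m, k] and B[n, k] are A^T[k, m]
// and B^T[k, n]. Computing
//   C^T[n, m] = B * A^T = (B^T)^T * (A^T)
// with the first operand transposed and the second plain yields a column-major
// C^T whose memory is exactly the row-major C[m, n] that JAX expects; this is
// presumably what the trans_lhs = true / trans_rhs = false flags below encode.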
if (num_gemms <= 0) {
return ffi_with_cuda_error_check();
}
size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms;
size_t expected_output_size = num_gemms + 1;
size_t actual_input_size = input_list.size();
size_t actual_output_size = output_list.size();
NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu",
expected_input_size, actual_input_size);
NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu",
expected_output_size, actual_output_size);
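// Flattened argument layout (inferred from the offsets used below), with N = num_gemms:
//   input_list  = [lhs_0..lhs_{N-1}, rhs_0..rhs_{N-1}, lhs_sinv_0..lhs_sinv_{N-1},
//                  rhs_sinv_0..rhs_sinv_{N-1}, bias_0..bias_{N-1} (only if has_bias)]
//   output_list = [out_0..out_{N-1}, workspace]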
bool trans_lhs = true;
bool trans_rhs = false;
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
@@ -79,10 +66,40 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
std::vector<NVTETensor> out_list;
std::vector<NVTETensor> workspace_list;
int lhs_list_offset = 0;
int rhs_list_offset = num_gemms;
int lhs_sinv_list_offset = 2 * num_gemms;
int rhs_sinv_list_offset = 3 * num_gemms;
int bias_list_offset = 4 * num_gemms;
int out_list_offset = 0;
for (int i = 0; i < num_gemms; i++) {
size_t m = dim_list_host[i * 3];
size_t n = dim_list_host[i * 3 + 1];
size_t k = dim_list_host[i * 3 + 2];
Buffer_Type lhs_i = input_list.get<Buffer_Type>(lhs_list_offset + i).value();
Buffer_Type rhs_i = input_list.get<Buffer_Type>(rhs_list_offset + i).value();
Buffer_Type lhs_sinv_i = input_list.get<Buffer_Type>(lhs_sinv_list_offset + i).value();
Buffer_Type rhs_sinv_i = input_list.get<Buffer_Type>(rhs_sinv_list_offset + i).value();
Result_Type out_i = output_list.get<Buffer_Type>(out_list_offset + i).value();
DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type());
DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type());
DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type());
void *lhs_ptr = lhs_i.untyped_data();
void *rhs_ptr = rhs_i.untyped_data();
void *lhs_sinv_ptr = lhs_sinv_i.untyped_data();
void *rhs_sinv_ptr = rhs_sinv_i.untyped_data();
void *out_ptr = out_i->untyped_data();
// Placeholder for bias since it can be empty
DType bias_dtype = DType::kFloat32;
void *bias_ptr = nullptr;
auto lhs_shape_ = lhs_i.dimensions();
auto rhs_shape_ = rhs_i.dimensions();
// lhs and rhs has shape [1, m, k] and [1, n, k]
size_t m = lhs_shape_[1];
size_t n = rhs_shape_[1];
size_t k = lhs_shape_[2];
auto lhs_shape = std::vector<size_t>{m, k};
auto rhs_shape = std::vector<size_t>{n, k};
@@ -90,54 +107,54 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
auto lhs_sinv_shape = std::vector<size_t>{1, 1};
auto rhs_sinv_shape = std::vector<size_t>{1, 1};
if (scaling_mode == NVTE_NO_SCALING || scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
auto lhs_i = TensorWrapper(static_cast<void *>(lhs_ptr), lhs_shape, lhs_dtype, nullptr,
nullptr, reinterpret_cast<float *>(lhs_sinv_ptr));
auto rhs_i = TensorWrapper(static_cast<void *>(rhs_ptr), rhs_shape, rhs_dtype, nullptr,
nullptr, reinterpret_cast<float *>(rhs_sinv_ptr));
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
} else if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires the K dim to be divisible by %d (got %d)",
MXFP8_BLOCK_SIZE, k);
size_t sinv_k = k / MXFP8_BLOCK_SIZE;
lhs_sinv_shape[0] = m;
lhs_sinv_shape[1] = sinv_k;
rhs_sinv_shape[0] = n;
rhs_sinv_shape[1] = sinv_k;
if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
float *amax_dptr = nullptr;
float *scale_dptr = nullptr;
auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
reinterpret_cast<float *>(lhs_sinv_ptr));
auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr,
reinterpret_cast<float *>(rhs_sinv_ptr));
lhs_wrapper_list.push_back(std::move(lhs_i_));
rhs_wrapper_list.push_back(std::move(rhs_i_));
} else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Note: the scale_inv array should have been swizzled in Python before lowering
TensorWrapper lhs_i(NVTE_MXFP8_1D_SCALING);
TensorWrapper rhs_i(NVTE_MXFP8_1D_SCALING);
lhs_i.set_rowwise_data(static_cast<void *>(lhs_ptr), lhs_dtype, lhs_shape);
rhs_i.set_rowwise_data(static_cast<void *>(rhs_ptr), rhs_dtype, rhs_shape);
lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat8E8M0,
lhs_sinv_shape);
rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat8E8M0,
rhs_sinv_shape);
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
auto lhs_sinv_shape_ = lhs_sinv_i.dimensions();
auto rhs_sinv_shape_ = rhs_sinv_i.dimensions();
for (int i = 0; i < 2; i++) {
lhs_sinv_shape[i] = lhs_sinv_shape_[i];
rhs_sinv_shape[i] = rhs_sinv_shape_[i];
}
NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode);
TensorWrapper lhs_i_(nvte_scaling_mode);
TensorWrapper rhs_i_(nvte_scaling_mode);
lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape);
rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape);
lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape);
rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape);
lhs_wrapper_list.push_back(std::move(lhs_i_));
rhs_wrapper_list.push_back(std::move(rhs_i_));
} else {
NVTE_ERROR("Unsupported scaling mode: ", scaling_mode);
NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode));
}
auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
lhs_ptr += m * k * lhs_dtype_bytes;
rhs_ptr += n * k * rhs_dtype_bytes;
out_ptr += m * n * out_dtype_bytes;
lhs_sinv_ptr += lhs_sinv_shape[0] * lhs_sinv_shape[1] * lhs_sinv_dtype_bytes;
rhs_sinv_ptr += rhs_sinv_shape[0] * rhs_sinv_shape[1] * rhs_sinv_dtype_bytes;
auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype);
void *pre_gelu_ptr = nullptr;
auto bias_shape = std::vector<size_t>{0};
auto pre_gelu_shape = std::vector<size_t>{0};
if (bias_ptr != nullptr) bias_shape[0] = n;
if (has_bias) {
auto bias_i_get = input_list.get<Buffer_Type>(bias_list_offset + i);
Buffer_Type bias_i = bias_i_get.value();
bias_ptr = bias_i.untyped_data();
bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type());
bias_shape[0] = n;
}
auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
if (bias_ptr != nullptr) bias_ptr += n * bias_dtype_bytes;
auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype);
out_wrapper_list.push_back(std::move(out_i));
out_wrapper_list.push_back(std::move(out_i_));
bias_wrapper_list.push_back(std::move(bias_i));
pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));
@@ -148,6 +165,10 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
out_list.push_back(out_wrapper_list.back().data());
}
auto workspace_get = output_list.get<Buffer_Type>(num_gemms);
Result_Type workspace = workspace_get.value();
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
size_t workspace_size = workspace->dimensions()[0] / num_streams;
auto workspace_shape = std::vector<size_t>{workspace_size};
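// The single workspace output is split evenly across the CUDA streams: stream i
// presumably works in [workspace_ptr + i * workspace_size, workspace_ptr + (i + 1) * workspace_size).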
for (int i = 0; i < num_streams; i++) {
auto workspace_i =
@@ -165,49 +186,14 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
return ffi_with_cuda_error_check();
}
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten,
Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten,
Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten,
Buffer_Type dim_list, Result_Type out_flatten,
Result_Type workspace_flatten, int64_t num_gemms, int64_t scaling_mode) {
// Inputs
auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_flatten.untyped_data());
auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_flatten.untyped_data());
auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv_flatten.untyped_data());
auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv_flatten.untyped_data());
auto bias_ptr = reinterpret_cast<uint8_t *>(bias_flatten.untyped_data());
auto dim_list_ptr = reinterpret_cast<int32_t *>(dim_list.untyped_data());
auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_flatten.element_type());
auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_flatten.element_type());
auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv_flatten.element_type());
auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv_flatten.element_type());
auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias_flatten.element_type());
// Outputs
auto out_ptr = reinterpret_cast<uint8_t *>(out_flatten->untyped_data());
auto out_dtype = convert_ffi_datatype_to_te_dtype(out_flatten->element_type());
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace_flatten->untyped_data());
auto workspace_size = workspace_flatten->dimensions().back() / num_streams;
return GroupedGemmImpl(lhs_ptr, lhs_dtype, lhs_sinv_ptr, lhs_sinv_dtype, rhs_ptr, rhs_dtype,
rhs_sinv_ptr, rhs_sinv_dtype, bias_ptr, bias_dtype, out_ptr, out_dtype,
workspace_ptr, workspace_size, num_gemms, dim_list_ptr, scaling_mode,
stream);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // lhs_flatten
.Arg<Buffer_Type>() // lhs_sinv_flatten
.Arg<Buffer_Type>() // rhs_flatten
.Arg<Buffer_Type>() // rhs_sinv_flatten
.Arg<Buffer_Type>() // bias_flatten
.Arg<Buffer_Type>() // dim_list
.Ret<Buffer_Type>() // out_flatten
.Ret<Buffer_Type>() // workspace_flatten
.RemainingArgs() // input list
.RemainingRets() // output list
.Attr<int64_t>("num_gemms")
.Attr<int64_t>("scaling_mode"),
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("has_bias"),
FFI_CudaGraph_Traits);
} // namespace jax
......
@@ -34,11 +34,34 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret;
}
enum class QuantizeAxis {
enum class QuantizeLayout {
ROWWISE,
COLWISE,
ROWWISE_COLWISE,
};
enum class JAXX_Scaling_Mode : int64_t {
NO_SCALING = 0,
DELAYED_TENSOR_SCALING = 1,
MXFP8_1D_SCALING = 2,
};
static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
switch (mode) {
case JAXX_Scaling_Mode::NO_SCALING:
return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
break;
case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING:
return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
break;
case JAXX_Scaling_Mode::MXFP8_1D_SCALING:
return NVTEScalingMode::NVTE_MXFP8_1D_SCALING;
break;
default:
NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
break;
}
}
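// Note: NO_SCALING also maps to NVTE_DELAYED_TENSOR_SCALING, presumably because
// TE/common has no dedicated "unscaled" tensor format; the two modes stay
// distinguishable on the JAX side via JAXX_Scaling_Mode itself, e.g.:
//   auto out = TensorWrapper(get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING));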
} // namespace jax
} // namespace transformer_engine
@@ -14,7 +14,8 @@ namespace jax {
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, DType out_dtype,
NVTE_Norm_Type norm_type, int scaling_mode,
NVTE_Norm_Type norm_type,
JAXX_Scaling_Mode scaling_mode,
bool zero_centered_gamma, float epsilon, int sm_margin,
bool is_training) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
@@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
auto _scaling_mode = static_cast<NVTEScalingMode>(scaling_mode);
auto output_tensor = TensorWrapper(_scaling_mode);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape);
// WAR: NVTE norms infer is_training from whether the columnwise_data is allocated
if (is_training && _scaling_mode == NVTE_MXFP8_1D_SCALING) {
if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
int temp = 1;
output_tensor.set_columnwise_data(static_cast<void *>(&temp), out_dtype, input_shape);
}
@@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr);
} else {
NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma,
NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma,
"rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), epsilon, output_tensor.data(),
rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma,
@@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
Result_Type colwise_scale_inv_buf, Result_Type amax_buf,
Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf,
int norm_type, bool zero_centered_gamma, double epsilon,
int64_t sm_margin, int scaling_mode, bool is_2x) {
int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
@@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto *workspace = wkspace_buf->untyped_data();
auto _scaling_mode = static_cast<NVTEScalingMode>(scaling_mode);
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x);
@@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin;
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
auto output_tensor = TensorWrapper(_scaling_mode);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);
if (is_fp8_dtype(out_dtype)) {
@@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
scale_inv_buf->dimensions().back()});
}
if (_scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
@@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(),
workspace_tensor.data(), num_sm, zero_centered_gamma, stream);
} else {
NVTE_CHECK(scaling_mode != NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING || !zero_centered_gamma,
NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || !zero_centered_gamma,
"rmsnorm doesn't support zero_centered_gamma.");
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), _epsilon, output_tensor.data(),
rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma,
@@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<bool>("zero_centered_gamma")
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<int64_t>("scaling_mode")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"),
FFI_CudaGraph_Traits);
......
@@ -138,17 +138,17 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("RMSNorm", NVTE_Norm_Type::RMSNorm)
.export_values();
pybind11::enum_<NVTEScalingMode>(m, "NVTE_Scaling_Mode", pybind11::module_local())
.value("NVTE_DELAYED_TENSOR_SCALING", NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING)
.value("NVTE_MXFP8_1D_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING)
.value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING)
pybind11::enum_<JAXX_Scaling_Mode>(m, "JAXX_Scaling_Mode", pybind11::module_local())
.value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING)
.value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
.value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
.export_values();
pybind11::enum_<transformer_engine::jax::QuantizeAxis>(m, "QuantizeAxis",
pybind11::module_local())
.value("ROWWISE", transformer_engine::jax::QuantizeAxis::ROWWISE)
.value("COLWISE", transformer_engine::jax::QuantizeAxis::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeAxis::ROWWISE_COLWISE)
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
pybind11::module_local())
.value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE)
.value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
.export_values();
}
......
@@ -13,7 +13,9 @@ namespace transformer_engine {
namespace jax {
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) {
DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
@@ -27,10 +29,37 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
int temp = 0;
auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
auto output_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), output_shape, out_dtype);
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_trans_shape);
auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) {
output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
}
}
if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) {
auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
}
}
if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
}
TensorWrapper dummy_workspace;
nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
@@ -42,10 +71,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf,
Result_Type amax_out_buf, Result_Type dbias_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum,
int64_t quantize_axis_enum, bool is_dbias) {
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
bool is_dbias, int64_t flatten_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
@@ -54,8 +83,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto *input = input_buf.untyped_data();
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto const quantize_axis = static_cast<QuantizeAxis>(quantize_axis_enum);
auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
@@ -63,9 +91,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
void *workspace = workspace_buf->untyped_data();
auto input_dims = input_buf.dimensions();
int64_t input_ndim = input_dims.size();
if (flatten_axis < 0) flatten_axis += input_ndim;
NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!");
auto workspace_dims = workspace_buf->dimensions();
auto m = product(input_dims, 0, input_dims.size() - 1);
auto n = input_dims.back();
auto m = product(input_dims, 0, flatten_axis);
auto n = product(input_dims, flatten_axis, input_ndim);
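// Example: input dims (B, S, H) with flatten_axis = -1 (normalized to 2) give
// m = B * S and n = H, while flatten_axis = 1 would give m = B and n = S * H.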
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
@@ -73,39 +105,58 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
std::vector<size_t> workspace_shape{workspace_dims.begin(), workspace_dims.end()};
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(scaling_mode);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
if (quantize_axis == QuantizeAxis::ROWWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) {
if (quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
}
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax_out, 0, sizeof(float), stream);
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1});
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
}
}
if (quantize_axis == QuantizeAxis::COLWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) {
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
if (quantize_layout == QuantizeLayout::COLWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf;
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? scale_inv_buf
: colwise_scale_inv_buf;
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
}
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
@@ -132,9 +183,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("scaling_mode")
.Attr<int64_t>("q_axis")
.Attr<bool>("is_dbias"),
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout")
.Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis"),
FFI_CudaGraph_Traits);
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
......
@@ -15,7 +15,11 @@ import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize import QuantizerSet, noop_quantizer_set
from .quantize import (
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
)
def dense(
@@ -23,6 +27,8 @@
kernel: jnp.ndarray,
bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
"""Perform dense layer transformation with optional quantization.
@@ -48,12 +54,12 @@
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
else:
output = _dense(x, kernel, bias, contracting_dims, quantizer_set)
output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set)
return output
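# A minimal usage sketch (hypothetical shapes; the default noop_quantizer_set keeps
# the computation unquantized, so no FP8 support is assumed):
#
#   x = jnp.ones((32, 128), jnp.bfloat16)
#   kernel = jnp.ones((128, 64), jnp.bfloat16)
#   out = dense(x, kernel, contracting_dims=((1,), (0,)))  # out.shape == (32, 64)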
@partial(jax.custom_vjp, nondiff_argnums=(3,))
def _dense(x, kernel, bias, contracting_dims, quantizer_set):
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
"""Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support
@@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set):
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
output, _ = _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set)
output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set
)
return output
def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
"""Forward pass rule for dense layer transformation.
Args:
x: Input tensor
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Tuple of (output, context) for backward pass
"""
x_contracting_dims, k_contracting_dims = contracting_dims
casted_x = tex.quantize(x, quantizer_set.x)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel)
flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
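# Example: x of shape (B, S, H) with x_contracting_dims = (2,) gives
# flatten_axis_x = -1, while kernel of shape (H, D1, D2) with
# k_contracting_dims = (0,) gives flatten_axis_k = 1 - 3 = -2.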
casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# GEMM NN
output = tex.gemm(
@@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
casted_kernel.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
)
use_bias = bias is not None
if use_bias:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
@@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
kernel.shape,
use_bias,
quantizer_set,
flatten_axis_k,
)
return output, ctx
def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argument
def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, ctx, grad
): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation.
Args:
contracting_dims: Contracting dimensions specification
ctx: Context from forward pass
grad: Gradient from upstream
Returns:
Tuple of gradients with respect to inputs
"""
@@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
kernel_shape,
use_bias,
quantizer_set,
flatten_axis_k,
) = ctx
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad)
casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad
)
# GEMM NT
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
@@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
rowwise_casted_kernel,
(g_constracting_dim, k_constracting_dim),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
# GEMM TN
# x_non_contracting_dims
@@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
wgrad = tex.gemm(
colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim)
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
return dgrad, wgrad, dbias, quantizer_set
......
@@ -13,7 +13,6 @@ import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name
@@ -26,8 +25,14 @@ from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available
from ..cpp_extensions import (
is_softmax_kernel_available,
jax_scaled_softmax,
jax_scaled_masked_softmax,
jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any
Shape = Tuple[int, ...]
@@ -167,10 +172,10 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
input_dtype = inputs.dtype
logits = inputs
if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
# use primitives
if is_softmax_kernel_available(
self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
):
if bias is not None:
logits = logits + bias.astype(input_dtype)
@@ -179,31 +184,22 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
mask_ = None
outputs = softmax(logits, mask_, self.scale_factor, self.softmax_type)
# use the default JAX-based implementation
else:
attention_bias = None
if mask is not None:
attention_bias = lax.select(
mask > 0,
jnp.full(mask.shape, -1e10),
jnp.full(mask.shape, 0.0),
)
attention_bias = attention_bias.astype(input_dtype)
if bias is not None:
attention_bias = _combine_biases(attention_bias, bias)
if attention_bias is not None:
logits = logits + attention_bias.astype(input_dtype)
logits = logits + bias.astype(input_dtype)
# If self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED but its kernel
# is unavailable, fall back to the pure scaled softmax custom call.
if is_softmax_kernel_available(
SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, input_dtype
):
outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
if self.softmax_type is SoftmaxType.SCALED:
outputs = jax_scaled_softmax(logits, self.scale_factor)
elif self.softmax_type is SoftmaxType.SCALED_MASKED:
outputs = jax_scaled_masked_softmax(logits, mask, self.scale_factor)
elif self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
outputs = jax_scaled_upper_triang_masked_softmax(logits, self.scale_factor)
else:
outputs = jax_nn.softmax(logits * self.scale_factor)
raise ValueError(
f"Unsupported softmax type: {self.softmax_type}. softmax_type must be [SCALED,"
" SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
)
assert input_dtype == outputs.dtype
return outputs
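# For reference, a pure-JAX fallback equivalent to `jax_scaled_softmax` would
# presumably look like this sketch (not the actual TE implementation):
#
#   def jax_scaled_softmax(logits, scale_factor):
#       return jax.nn.softmax(scale_factor * logits, axis=-1)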
@@ -360,7 +356,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
).value
return QuantizeMeta(scale=scale, amax_history=amax_history)
if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
x_meta = generate_quantize_meta("x")
kernel_meta = generate_quantize_meta("kernel")
grad_meta = generate_quantize_meta("grad")
@@ -406,6 +402,10 @@ class DenseGeneral(TransformerEngineBase):
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
Optimization parameters
-----------------------
@@ -429,6 +429,7 @@
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
input_axes: Tuple[str, ...] = ()
def __post_init__(self):
if self.kernel_init is None:
@@ -460,29 +461,35 @@
axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
if self.kernel_axes:
assert len(kernel_shape) == len(self.kernel_axes), (
"Expected len(kernel_shape) to match len(kernel_axes),"
f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
)
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
)
if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype)
kernel_compute_shape = (
reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
reduce(operator.mul, features, 1),
)
kernel = jnp.reshape(kernel, kernel_compute_shape)
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
)
bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
).astype(input_dtype)
else:
bias = None
quantizer_set = self.generate_quantizer_set()
contract_ind = tuple(range(0, len(axis)))
y = dense(
inputs, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set
inputs,
kernel,
contracting_dims=(axis, contract_ind),
input_axes=self.input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
)
if self.enable_low_rank_adaptation:
@@ -491,20 +498,14 @@
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_init_shape = (
kernel_compute_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
lora_a_kernel_shape,
self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
......@@ -527,7 +528,6 @@ class DenseGeneral(TransformerEngineBase):
y += jnp.reshape(bias, bias_shape)
assert y.dtype == input_dtype
y = y.reshape(*inputs.shape[: self.axis], *features)
return y
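# [Editor's sketch] Hypothetical usage of DenseGeneral with the new input_axes /
# kernel_axes sharding hints; the logical axis names and shapes below are
# illustrative assumptions, and a real run may additionally need a mesh and
# logical-axis rules in scope for the sharding constraints to take effect.
import jax
import jax.numpy as jnp

layer = DenseGeneral(
    features=1024,
    kernel_axes=("embed", "mlp"),          # one name per kernel dim
    input_axes=("batch", "seq", "embed"),  # one name per input dim
)
x = jnp.ones((8, 128, 512), jnp.bfloat16)
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)                 # y.shape == (8, 128, 1024)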
......@@ -678,6 +678,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
assert self.axis == -1, "Only support axis = =-1 at this moment"
input_dtype = inputs.dtype
ln_output = None
......@@ -692,10 +693,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
assert self.axis == -1 # Only support axis = =-1 at this moment
features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters(
self.layernorm_type,
(features,),
......@@ -731,17 +729,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis = _normalize_axes(axis, y.ndim)
kernel_shape = tuple(y.shape[ax] for ax in axis) + features
kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
)
if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype)
kernel_compute_shape = (
reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
reduce(operator.mul, features, 1),
)
kernel = jnp.reshape(kernel, kernel_compute_shape)
contract_ind = tuple(range(0, len(axis)))
......@@ -756,11 +749,19 @@ class LayerNormDenseGeneral(TransformerEngineBase):
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
)
else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
z = dense(y, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set)
z = dense(
y,
kernel,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
)
if self.enable_low_rank_adaptation:
lora_a_kernel_shape = (
......@@ -768,20 +769,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_init_shape = (
kernel_compute_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
lora_a_kernel_shape,
self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
......@@ -803,8 +798,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
)
bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
).astype(input_dtype)
if bias is not None:
bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
......@@ -814,7 +808,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
z = z / self.depth_scaling
assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
z = z.reshape(*inputs.shape[: self.axis], *features)
# z = z.reshape(*inputs.shape[: self.axis], *features)
return z, ln_output # dense_output, layer_norm_output
......@@ -989,6 +983,8 @@ class LayerNormMLP(TransformerEngineBase):
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
assert self.axis == -1, "Only support axis == -1 at this moment"
ffn1_quantizer_set = self.generate_quantizer_set("_0")
ffn2_quantizer_set = self.generate_quantizer_set("_1")
......@@ -1027,7 +1023,6 @@ class LayerNormMLP(TransformerEngineBase):
)
# LayerNorm
if self.enable_layernorm:
assert self.axis == -1 # Only support axis == -1 at this moment
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
features = inputs.shape[-1]
......@@ -1071,7 +1066,7 @@ class LayerNormMLP(TransformerEngineBase):
num_activations = len(normalized_acts)
axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, y.ndim)
kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1 = nn_partitioning.param_with_axes(
"wi_kernel",
kernel_1_init,
......@@ -1081,13 +1076,10 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype,
axes=self.kernel_axes_1,
)
kernel_1_compute_shape = (
reduce(operator.mul, [y.shape[ax] for ax in axis], 1),
num_activations * self.intermediate_dim,
)
kernel_1 = jnp.reshape(kernel_1, kernel_1_compute_shape)
if not QuantizeConfig.is_fp8_enabled():
kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size)
kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
......@@ -1098,26 +1090,20 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype,
axes=self.kernel_axes_2,
)
kernel_2_compute_shape = (
self.intermediate_dim,
reduce(operator.mul, hidden_size_tuple, 1),
)
kernel_2 = jnp.reshape(kernel_2, kernel_2_compute_shape)
if not QuantizeConfig.is_fp8_enabled():
kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
if self.use_bias:
bias_1_shape = num_activations * self.intermediate_dim
bias_1_shape = (num_activations, self.intermediate_dim)
bias_1 = nn_partitioning.param_with_axes(
"wi_bias",
self.bias_init,
bias_1_shape,
self.dtype,
axes=self.bias_axes_1,
)
bias_1 = bias_1.reshape(kernel_1_compute_shape[-1]).astype(input_dtype)
).astype(input_dtype)
bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes(
......@@ -1126,8 +1112,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_2_shape,
self.dtype,
axes=self.bias_axes_2,
)
bias_2 = bias_2.reshape(kernel_2_compute_shape[-1]).astype(input_dtype)
).astype(input_dtype)
else:
bias_1 = None
bias_2 = None
......@@ -1136,8 +1121,6 @@ class LayerNormMLP(TransformerEngineBase):
ffn2_ckpt_name = "ffn2"
if use_fused_layernorm_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment
out = layernorm_mlp(
y,
scale,
......@@ -1150,6 +1133,8 @@ class LayerNormMLP(TransformerEngineBase):
norm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
kernel_1_axes=self.kernel_axes_1,
kernel_2_axes=self.kernel_axes_2,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name,
activation_type=normalized_acts,
......@@ -1170,6 +1155,7 @@ class LayerNormMLP(TransformerEngineBase):
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set,
)
else:
......@@ -1178,35 +1164,31 @@ class LayerNormMLP(TransformerEngineBase):
y,
kernel_1,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set,
)
dot_1_output_axes = (
*get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
*get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind),
)
x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)
if self.enable_low_rank_adaptation:
wi_lora_a_kernel_shape = (
kernel_1_compute_shape[0],
num_activations,
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_init_shape = (
kernel_1_each_shape[0],
num_activations,
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_init_each_shape = (
kernel_1_each_shape[0],
wi_lora_a_kernel_each_shape = (
kernel_1_each_shape[: len(axis)],
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
wi_lora_a_kernel = nn_partitioning.param_with_axes(
"wi_lora_a_kernel",
kernel_1_init,
num_activations,
-1,
wi_lora_a_kernel_init_each_shape,
-2,
wi_lora_a_kernel_each_shape,
self.dtype,
axes=wi_lora_a_kernel_axes,
)
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)
wi_lora_b_kernel_shape = (
......@@ -1227,7 +1209,7 @@ class LayerNormMLP(TransformerEngineBase):
x += _apply_low_rank_adaptation(
y,
axis,
num_activations * self.intermediate_dim,
(num_activations, self.intermediate_dim),
wi_lora_a_kernel,
wi_lora_b_kernel,
self.low_rank_adaptation_alpha,
......@@ -1241,11 +1223,12 @@ class LayerNormMLP(TransformerEngineBase):
z = activation(x, normalized_acts)
else:
activations = []
x = jnp.split(x, num_activations, axis=-1)
x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(normalized_acts):
x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i)
z = reduce(operator.mul, activations)
z = jnp.squeeze(z, axis=-2)
z = z.astype(input_dtype)
z = nn.Dropout(
......@@ -1259,7 +1242,12 @@ class LayerNormMLP(TransformerEngineBase):
# DenseGeneral 2
out = dense(
z, kernel_2, contracting_dims=(axis, contract_ind), quantizer_set=ffn2_quantizer_set
z,
kernel_2,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_2_input_axes,
kernel_axes=self.kernel_axes_2,
quantizer_set=ffn2_quantizer_set,
)
if self.enable_low_rank_adaptation:
......
......@@ -220,11 +220,11 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if mask is not None:
mask = apply_swa_mask(mask)
# Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
if mask is not None:
return SoftmaxType.SCALED_MASKED, mask
if attn_mask_type is AttnMaskType.CAUSAL_MASK:
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]:
if mask is not None:
return SoftmaxType.SCALED_MASKED, mask
if attn_mask_type is AttnMaskType.NO_MASK:
return SoftmaxType.SCALED, mask
raise ValueError(
f"Unsupported {attn_mask_type=}, supported attn_mask_type="
......@@ -447,6 +447,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
.. note:: THD format only supports 'padding' or 'causal_padding' mask type.
attn_mask_type      mask/sequence_descriptor     SWA       softmax type
--------------------------------------------------------------------------------------------
no_mask             None                         None      SCALED
causal              None                         None      SCALED_UPPER_TRIANG_MASKED
causal              None                         Yes       SCALED_MASKED
padding             Required                     Yes/No    SCALED_MASKED
padding_causal      Required                     Yes/No    SCALED_MASKED
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
......
......@@ -33,10 +33,9 @@ def layernorm_dense(
norm_type: str = "layernorm",
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
# The logic axes of sharding constraint to the layernorm input.
layernorm_input_axes: Tuple[str, ...] = None,
# The logic axes of sharding constraint to the dot input.
dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation.
......@@ -56,6 +55,7 @@ def layernorm_dense(
epsilon: Small constant for numerical stability in normalization
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: Set of quantizers for different tensor types
Returns:
......@@ -78,6 +78,7 @@ def layernorm_dense(
epsilon,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
quantizer_set,
)
return output
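# [Editor's sketch] Illustrative call of the functional API above; the argument
# order follows the internal _layernorm_dense below, and the shapes and logical
# axis names are assumptions. kernel is expected as (hidden_in, hidden_out...).
import jax.numpy as jnp

x = jnp.ones((8, 128, 512), jnp.bfloat16)      # (batch..., hidden_in)
kernel = jnp.ones((512, 1024), jnp.bfloat16)   # (hidden_in, hidden_out)
gamma = jnp.ones((512,), jnp.float32)          # layernorm scale
beta = jnp.zeros((512,), jnp.float32)          # layernorm bias
out = layernorm_dense(x, kernel, gamma, beta, kernel_axes=("embed", "mlp"))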
......@@ -91,6 +92,7 @@ def layernorm_dense(
7,
8,
9,
10,
),
)
def _layernorm_dense(
......@@ -104,6 +106,7 @@ def _layernorm_dense(
epsilon: float,
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...],
quantizer_set,
):
"""Internal implementation of layernorm_dense with custom VJP.
......@@ -139,6 +142,7 @@ def _layernorm_dense(
epsilon,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
quantizer_set,
)
return output
......@@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule(
epsilon,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
quantizer_set,
):
"""Forward pass rule for layernorm_dense.
......@@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule(
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0]
assert len(kernel.shape) == 2 # Otherwise need to merge dims in quantize
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
......@@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule(
norm_type,
quantizer_set.x,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# Kernel in (hidden_in, hidden_out...)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
flatten_axis = 1 - len(kernel.shape)
casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...)
......@@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims,
use_bias,
quantizer_set,
flatten_axis,
)
return output, ctx
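# [Editor's sketch] Worked shape example for the flatten_axis computed above:
# for a kernel of shape (hidden_in, heads, head_dim) = (512, 8, 64),
# flatten_axis = 1 - 3 = -2, so the quantizer scales the tensor as the 2D
# matrix (512, 8 * 64) while the stored data keeps its original shape.
kernel_shape = (512, 8, 64)
flatten_axis = 1 - len(kernel_shape)       # -> -2
rows = kernel_shape[0]
cols = 1
for d in kernel_shape[flatten_axis:]:      # dims merged into the 2D column axis
    cols *= d
print((rows, cols))                        # (512, 512)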
......@@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule(
epsilon,
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
kernel_axes,
ctx,
grad,
):
......@@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule(
k_contracting_dims_in_fwd,
use_bias,
quantizer_set,
flatten_axis,
) = ctx
grad = with_sharding_constraint_by_logical_axes(grad, dot_input_axes)
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad)
casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
g_constracting_dim = tuple(
......@@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule(
(x_constracting_dim, g_constracting_dim),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
dx, dgamma, dbeta = tex.normalization_bwd(
dgrad,
x,
......
......@@ -23,6 +23,7 @@ from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type
from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set
from .sharding import get_non_contracting_logical_axes
def layernorm_mlp(
......@@ -37,6 +38,8 @@ def layernorm_mlp(
norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
kernel_1_axes: Tuple[str, ...] = None,
kernel_2_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
......@@ -66,6 +69,8 @@ def layernorm_mlp(
norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
kernel_1_axes: Logical axes for sharding the first weight matrix
kernel_2_axes: Logical axes for sharding the second weight matrix
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
......@@ -109,6 +114,8 @@ def layernorm_mlp(
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
......@@ -117,7 +124,7 @@ def layernorm_mlp(
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
......@@ -132,6 +139,8 @@ def _layernorm_mlp(
norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...],
kernel_1_axes: Tuple[str, ...],
kernel_2_axes: Tuple[str, ...],
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
......@@ -179,6 +188,8 @@ def _layernorm_mlp(
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
......@@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule(
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
......@@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule(
Returns:
Tuple of (output, context) for automatic differentiation
"""
del kernel_2_axes
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len * intermediate)
# Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
# Kernel_2 should be in shape of (intermediate, hidden_in)
assert len(kernel_1.shape) == 2
assert len(kernel_1.shape) == 3
assert len(kernel_2.shape) == 2
assert kernel_1.shape[1] == kernel_2.shape[0] * len(activation_type)
assert kernel_1.shape[-2] == len(activation_type)
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
assert kernel_1.shape[1] == len(activation_type) * kernel_2.shape[0]
use_bias_1 = bias_1 is not None
use_bias_2 = bias_1 is not None
......@@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule(
norm_type,
quantizer=ffn1_quantizer_set.x,
)
casted_kernel_1 = tex.quantize(kernel_1, quantizer=ffn1_quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = tex.gemm(
......@@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule(
casted_kernel_1.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
)
dot_1_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
)
dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
if use_bias_1:
bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
......@@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule(
(x_contracting_dims, k_contracting_dims),
)
dot_2_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_2_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_2.ndim, None, k_contracting_dims),
)
dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_2_output_axes)
if use_bias_2:
bias_2_shape = bias_2.shape
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
......@@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule(
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument
ffn2_ckpt_name, # pylint: disable=unused-argument
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
ctx,
grad,
......@@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule(
Returns:
Tuple of gradients for all input parameters
"""
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
(
x,
mu,
......@@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule(
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_2 = tuple(
g_contracting_dims_2 = tuple(
range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
)
# k_non_contracting_dims
k_constracting_dim_2 = tuple(
k_contracting_dims_2 = tuple(
dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
)
......@@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule(
dgrad_2 = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel_2,
(g_constracting_dim_2, k_constracting_dim_2),
(g_contracting_dims_2, k_contracting_dims_2),
)
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
x_constracting_dim = g_constracting_dim = tuple(
x_contracting_dims = g_contracting_dims = tuple(
range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
)
......@@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule(
wgrad_2 = tex.gemm(
colwise_casted_act_out,
casted_grad.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim),
(x_contracting_dims, g_contracting_dims),
)
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
dgrad_2,
......@@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule(
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_1 = tuple(
range(dgrad_2.ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dgrad_2.ndim)
dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim
g_contracting_dims_1 = tuple(
range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
)
# k_non_contracting_dims
k_constracting_dim_1 = tuple(
k_contracting_dims_1 = tuple(
dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
)
......@@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule(
dgrad_1 = tex.gemm(
casted_dact_out.get_rowwise_tensor(),
rowwise_casted_kernel_1,
(g_constracting_dim_1, k_constracting_dim_1),
(g_contracting_dims_1, k_contracting_dims_1),
)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, norm_input_axes)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
# TN GEMM
# (hidden, batch...) x (hidden, batch...)
wgrad_1 = tex.gemm(
colwise_casted_ln_out,
casted_dact_out.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim),
(x_contracting_dims, g_contracting_dims),
)
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
dx, dgamma, dbeta = tex.normalization_bwd(
dgrad_1,
x,
......
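# [Editor's sketch] Worked example of the contracting-dim arithmetic used in the
# backward rule above, under assumed shapes: dact_out of shape
# (batch, seq, act, intermediate) (ndim 4), kernel_1 of ndim 3, and
# k_contracting_dims_in_fwd = (0,) from the forward pass.
dact_out_ndim = 4
kernel_1_ndim = 3
k_contracting_dims_in_fwd = (0,)
g_contracting_dims_1 = tuple(
    range(dact_out_ndim - kernel_1_ndim + len(k_contracting_dims_in_fwd), dact_out_ndim)
)  # -> (2, 3): contract over the (act, intermediate) dims of the gradient
k_contracting_dims_1 = tuple(
    d for d in range(kernel_1_ndim) if d not in k_contracting_dims_in_fwd
)  # -> (1, 2): the kernel dims that were non-contracting in the forward GEMM
print(g_contracting_dims_1, k_contracting_dims_1)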
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Praxis related Modules"""
from .module import FusedSoftmax, LayerNorm
from .module import LayerNormLinear, LayerNormMLP, Linear, TransformerEngineBaseLayer
from .transformer import DotProductAttention, MultiHeadAttention
from .transformer import RelativePositionBiases, TransformerLayer
from ..flax.transformer import TransformerLayerType
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules
"""
from dataclasses import field
from functools import partial
from typing import Callable, Iterable, Sequence, Tuple, Union
from praxis import pax_fiddle
from praxis.base_layer import init_var
from praxis.base_layer import BaseLayer, WeightInit, WeightHParams, WeightHParamsCollection
from praxis.layers import flax_adapter
from praxis.pytypes import JTensor
from ..fp8 import FP8Helper
from ..flax.module import DenseGeneral, LayerNormDenseGeneral
from ..flax.module import LayerNorm as flax_LayerNorm
from ..flax.module import LayerNormMLP as flax_LayerNormMLP
from ..flax.module import Softmax
from ..softmax import SoftmaxType
def _generate_ln_scale_init(scale_init):
if scale_init is not None:
return TransformerEngineBaseLayer.generate_params_init("scale", scale_init)
return scale_init
class TransformerEngineBaseLayer(BaseLayer):
"""TransformerEngineBaseLayer"""
logical_axes_rules: Tuple[Tuple, ...] = None
@staticmethod
def generate_params_init(name: str, initializer: WeightInit):
"""generate_params_init"""
def kernel_init(key, shape, dtype):
wp = WeightHParams(shape=shape, init=initializer, dtype=dtype)
return init_var(wp, key, name)
return kernel_init
def create_layer(self, name, flax_module_cls):
"""create_layer"""
fp8_collection_map = {
FP8Helper.FP8_COLLECTION_NAME: [
WeightHParamsCollection.SKIP_LP_REGULARIZATION,
WeightHParamsCollection.OVERWRITE_WITH_GRADIENT,
WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION,
]
}
flax_module_p = pax_fiddle.Config(
flax_adapter.FlaxModuleAdapter,
module_factory_method=flax_module_cls,
logical_axes_rules=self.logical_axes_rules,
var_collection_map=fp8_collection_map,
ici_mesh_shape=self.ici_mesh_shape,
dcn_mesh_shape=self.dcn_mesh_shape,
mesh_axis_names=self.mesh_axis_names,
)
self.create_child(name, flax_module_p.clone())
class LayerNorm(TransformerEngineBaseLayer):
"""LayerNorm"""
epsilon: float = 1e-6
layernorm_type: str = "layernorm"
zero_centered_gamma: bool = False
scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = ()
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
bias_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False
def setup(self) -> None:
"""setup"""
super().setup()
ln_cls = partial(
flax_LayerNorm,
epsilon=self.epsilon,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
scale_init=_generate_ln_scale_init(self.scale_init),
scale_axes=self.scale_axes,
bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", self.bias_init),
bias_axes=self.bias_axes,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
)
self.create_layer("layer_norm", ln_cls)
def __call__(self, x: JTensor) -> JTensor:
"""__call__"""
return self.layer_norm(x)
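# [Editor's sketch] Hypothetical configuration of the Praxis wrapper above via
# pax_fiddle; base_layer.instantiate is the usual Praxis entry point, but the
# mesh and sharding setup required for a real run is omitted here.
from praxis import base_layer, pax_fiddle

ln_p = pax_fiddle.Config(LayerNorm, name="te_layer_norm", epsilon=1e-6)
ln = base_layer.instantiate(ln_p)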
class FusedSoftmax(TransformerEngineBaseLayer):
"""FusedSoftmax"""
scale_factor: float = 1.0
softmax_type: SoftmaxType = SoftmaxType.SCALED
def setup(self) -> None:
"""setup"""
super().setup()
fused_softmax_cls = partial(
Softmax, scale_factor=self.scale_factor, softmax_type=self.softmax_type
)
self.create_layer("fused_softmax", fused_softmax_cls)
def __call__(self, x: JTensor, mask: JTensor = None, bias: JTensor = None) -> JTensor:
"""__call__"""
return self.fused_softmax(x, mask, bias)
class Linear(TransformerEngineBaseLayer):
"""Linear"""
out_features: int = 512
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = True
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
def setup(self) -> None:
"""setup"""
super().setup()
dense_general_cls = partial(
DenseGeneral,
features=self.out_features,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
kernel_axes=self.kernel_axes,
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes=self.bias_axes,
enable_low_rank_adaptation=self.enable_low_rank_adaptation,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
)
self.create_layer("linear", dense_general_cls)
def __call__(self, x: JTensor) -> JTensor:
"""__call__"""
return self.linear(x)
class LayerNormLinear(TransformerEngineBaseLayer):
"""LayerNormLinear"""
out_features: int = 512
enable_layernorm: bool = True
layernorm_type: str = "layernorm"
epsilon: float = 1e-6
zero_centered_gamma: bool = False
scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = ()
ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=1.0)
)
ln_bias_axes: Tuple[str, ...] = ()
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = False
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
return_layernorm_output: bool = True
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
depth_scaling: float = None
def setup(self) -> None:
"""setup"""
super().setup()
ln_dense_general_cls = partial(
LayerNormDenseGeneral,
features=self.out_features,
enable_layernorm=self.enable_layernorm,
layernorm_type=self.layernorm_type,
epsilon=self.epsilon,
zero_centered_gamma=self.zero_centered_gamma,
scale_init=_generate_ln_scale_init(self.scale_init),
scale_axes=self.scale_axes,
ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
"ln_bias", self.ln_bias_init
),
ln_bias_axes=self.ln_bias_axes,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
kernel_axes=self.kernel_axes,
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes=self.bias_axes,
enable_low_rank_adaptation=self.enable_low_rank_adaptation,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
return_layernorm_output=self.return_layernorm_output,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
depth_scaling=self.depth_scaling,
)
self.create_layer("ln_linear", ln_dense_general_cls)
def __call__(self, x: JTensor) -> JTensor:
"""__call__"""
return self.ln_linear(x)
class LayerNormMLP(TransformerEngineBaseLayer):
"""LayerNormMLP"""
intermediate_dim: int = 2048
enable_layernorm: bool = True
layernorm_type: str = "layernorm"
epsilon: float = 1e-6
zero_centered_gamma: bool = False
scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = ()
ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=1.0)
)
ln_bias_axes: Tuple[str, ...] = ()
kernel_axes_1: Tuple[str, ...] = ()
kernel_axes_2: Tuple[str, ...] = ()
use_bias: bool = False
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
bias_axes_1: Tuple[str, ...] = ()
bias_axes_2: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ("relu",)
intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = ()
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
def setup(self) -> None:
"""setup"""
super().setup()
ln_mlp_cls = partial(
flax_LayerNormMLP,
intermediate_dim=self.intermediate_dim,
enable_layernorm=self.enable_layernorm,
layernorm_type=self.layernorm_type,
epsilon=self.epsilon,
zero_centered_gamma=self.zero_centered_gamma,
scale_init=_generate_ln_scale_init(self.scale_init),
scale_axes=self.scale_axes,
ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
"ln_bias", self.ln_bias_init
),
ln_bias_axes=self.ln_bias_axes,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
kernel_axes_1=self.kernel_axes_1,
kernel_axes_2=self.kernel_axes_2,
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes_1=self.bias_axes_1,
bias_axes_2=self.bias_axes_2,
enable_low_rank_adaptation=self.enable_low_rank_adaptation,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
return_layernorm_output=self.return_layernorm_output,
activations=self.activations,
intermediate_dropout_rate=self.intermediate_dropout_rate,
intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
)
self.create_layer("ln_mlp", ln_mlp_cls)
def __call__(self, x: JTensor, deterministic: bool = False) -> JTensor:
"""__call__"""
return self.ln_mlp(x, deterministic)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules related to Transformer
"""
from dataclasses import field
from functools import partial
from typing import Optional, Sequence, Tuple
import warnings
from praxis import pax_fiddle
from praxis.base_layer import WeightInit
from praxis.pytypes import JTensor
from .module import TransformerEngineBaseLayer
from ..flax.transformer import TransformerLayerType
from ..flax.transformer import DotProductAttention as flax_DotProductAttention
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer
from ..attention import AttnBiasType, AttnMaskType
class RelativePositionBiases(TransformerEngineBaseLayer):
"""RelativePositionBiases"""
num_buckets: int = 32
max_distance: int = 128
num_attention_heads: int = 64
embedding_init: WeightInit = None
embedding_axes: Tuple[str, ...] = ()
@staticmethod
def generate_embedding_init(init, num_attention_heads, num_buckets):
"""generate_embedding_init"""
embedding_init = init
if embedding_init is None:
rb_stddev = (num_attention_heads * num_buckets) ** -0.5
embedding_init = WeightInit.Gaussian(rb_stddev)
return embedding_init
def setup(self) -> None:
"""setup"""
super().setup()
embedding_init = RelativePositionBiases.generate_embedding_init(
self.embedding_init, self.num_attention_heads, self.num_buckets
)
rpb_cls = partial(
flax_RelativePositionBiases,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
num_attention_heads=self.num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", embedding_init
),
embedding_axes=self.embedding_axes,
dtype=self.dtype,
)
self.create_layer("relative_position_bias", rpb_cls)
def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = True) -> JTensor:
"""__call__"""
return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
class DotProductAttention(TransformerEngineBaseLayer):
"""DotProductAttention"""
head_dim: int = 0
num_attention_heads: int = 0
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.0
attn_mask_type: AttnMaskType = "causal"
attn_bias_type: AttnBiasType = None
dropout_rng_name: str = "dropout"
float32_logits: bool = False
qkv_layout: str = "bshd_bshd_bshd"
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
window_size: Optional[Tuple[int, int]] = None
def setup(self) -> None:
"""setup"""
super().setup()
assert self.head_dim > 0, f"{self.head_dim=}"
assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"
dpa_cls = partial(
flax_DotProductAttention,
head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
attention_dropout=self.attention_dropout,
dtype=self.dtype,
dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_logits,
qkv_layout=self.qkv_layout,
scale_factor=self.scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
window_size=self.window_size,
)
self.create_layer("dot_product_attention", dpa_cls)
def __call__(
self,
query: JTensor,
key: JTensor,
value: JTensor,
mask: Optional[JTensor] = None,
bias: Optional[JTensor] = None,
*,
deterministic: bool = False,
) -> JTensor:
"""__call__"""
return self.dot_product_attention(
query, key, value, mask, bias, deterministic=deterministic
)
class MultiHeadAttention(TransformerEngineBaseLayer):
"""MultiHeadAttention"""
head_dim: int = 0
num_attention_heads: int = 0
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.0
dropout_rng_name: str = "dropout"
input_layernorm: bool = True
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
return_layernorm_output: bool = False
use_bias: bool = False
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
attn_mask_type: str = "causal"
attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = "consecutive"
low_rank_adaptation_scope: str = "none"
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
float32_logits: bool = False
window_size: Optional[Tuple[int, int]] = None
# Deprecated parameters
num_heads: Optional[int] = None
dropout_rate: Optional[float] = None
output_layernorm: Optional[bool] = None
apply_residual_connection_post_layernorm: Optional[bool] = None
fuse_qkv: Optional[bool] = None
def __post_init__(self):
# Deal with the deprecated parameters
if self.num_heads is not None:
self.num_attention_heads = self.num_heads
warnings.warn(
f"{__class__}.num_heads is deprecated. It will be removed recently. "
f"Please uses {__class__}.num_attention_heads as the new API.",
DeprecationWarning,
)
if self.dropout_rate is not None:
self.attention_dropout = self.dropout_rate
warnings.warn(
f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
f"Please use {__class__}.attention_dropout as the new API.",
DeprecationWarning,
)
if self.apply_residual_connection_post_layernorm is not None:
warnings.warn(
f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
f"It will be removed recently, please use {__class__}.return_layernorm_output.",
DeprecationWarning,
)
if self.fuse_qkv is not None:
warnings.warn(
f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
f"Please use {__class__}.fuse_qkv_params as the new API.",
DeprecationWarning,
)
assert self.output_layernorm is None, (
f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
def setup(self) -> None:
"""setup"""
super().setup()
assert self.head_dim > 0, f"{self.head_dim=}"
assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"
mha_cls = partial(
flax_MultiHeadAttention,
dtype=self.dtype,
head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
attention_dropout=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
input_layernorm=self.input_layernorm,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
return_layernorm_output=self.return_layernorm_output,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
low_rank_adaptation_scope=self.low_rank_adaptation_scope,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
float32_logits=self.float32_logits,
window_size=self.window_size,
)
self.create_layer("multi_head_attn", mha_cls)
def __call__(
self,
inputs_q: JTensor,
inputs_kv: JTensor,
mask: Optional[JTensor] = None,
bias: Optional[JTensor] = None,
*,
decode: bool = False,
deterministic: bool = False,
) -> JTensor:
"""__call__"""
return self.multi_head_attn(
inputs_q, inputs_kv, mask, bias, decode=decode, deterministic=deterministic
)
class TransformerLayer(TransformerEngineBaseLayer):
"""TransformerLayer"""
hidden_size: int = 512
mlp_hidden_size: int = 2048
num_attention_heads: int = 8
num_gqa_groups: Optional[int] = None
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
intermediate_dropout: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
dropout_rng_name: str = "dropout"
mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = "causal"
self_attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = "consecutive"
low_rank_adaptation_scope: str = "none"
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
enable_relative_embedding: bool = True
relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
drop_path: float = 0.0
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = False
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
window_size: Optional[Tuple[int, int]] = None
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
def setup(self) -> None:
"""setup"""
super().setup()
relative_embedding_flax_module = None
if self.enable_relative_embedding and self.relative_embedding is not None:
assert self.relative_embedding.num_attention_heads == self.num_attention_heads, (
"TransformerLayer.relative_embedding.num_attention_heads shoule be"
"the same as TransformerLayer.num_attention_heads."
)
embedding_init = RelativePositionBiases.generate_embedding_init(
self.relative_embedding.embedding_init,
self.relative_embedding.num_attention_heads,
self.relative_embedding.num_buckets,
)
relative_embedding_flax_module = flax_RelativePositionBiases(
num_buckets=self.relative_embedding.num_buckets,
max_distance=self.relative_embedding.max_distance,
num_attention_heads=self.relative_embedding.num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", embedding_init
),
embedding_axes=self.relative_embedding.embedding_axes,
dtype=self.relative_embedding.dtype,
)
transformerlayer_cls = partial(
flax_TransformerLayer,
dtype=self.dtype,
hidden_size=self.hidden_size,
mlp_hidden_size=self.mlp_hidden_size,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
hidden_dropout=self.hidden_dropout,
hidden_dropout_dims=self.hidden_dropout_dims,
attention_dropout=self.attention_dropout,
intermediate_dropout=self.intermediate_dropout,
intermediate_dropout_dims=self.intermediate_dropout_dims,
dropout_rng_name=self.dropout_rng_name,
mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mha_kernel", self.params_init
),
mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mlp_kernel", self.params_init
),
mlp_activations=self.mlp_activations,
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm,
float32_attention_logits=self.float32_attention_logits,
layer_type=self.layer_type,
self_attn_mask_type=self.self_attn_mask_type,
self_attn_bias_type=self.self_attn_bias_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
low_rank_adaptation_scope=self.low_rank_adaptation_scope,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
enable_relative_embedding=self.enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=self.drop_path,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
window_size=self.window_size,
)
self.create_layer("transformerlayer", transformerlayer_cls)
def __call__(
self,
inputs: JTensor,
encoded: JTensor = None,
attention_mask: JTensor = None,
encoder_decoder_mask: JTensor = None,
deterministic: bool = False,
decode: bool = False,
max_decode_length: Optional[int] = None,
) -> JTensor:
"""__call__"""
return self.transformerlayer(
inputs,
encoded,
attention_mask,
encoder_decoder_mask,
deterministic,
decode,
max_decode_length,
)
......@@ -57,26 +57,35 @@ class Dequantizer:
data = scaled_tensor.data.astype(jnp.float32)
data_shape = data.shape
scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32)
flatten_axis = scaled_tensor.flatten_axis
flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
scale_shape = scaled_tensor.scaling_mode.get_scale_shape(
scaled_tensor.data.shape, scaled_tensor.is_colwise, is_padded=False
data_shape, scaled_tensor.is_colwise, is_padded=False, flatten_axis=flatten_axis
)
scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding
data = data.reshape(
*data_shape[:-2],
scale_shape[-2],
int(data_shape[-2] / scale_shape[-2]),
*data_shape[: flatten_axis - 1],
scale_shape[flatten_axis - 1],
int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*data_shape[flatten_axis:-1],
scale_shape[-1],
int(data_shape[-1] / scale_shape[-1]),
)
scale = jnp.expand_dims(scale, axis=(-1, -3))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
scale = jnp.expand_dims(scale, axis=(flatten_axis + 2 - 2, -1))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape(
data_shape
)
funcs = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling,
ScalingMode.NVTE_MXFP8_1D_SCALING: _dq_func_block_scaling,
ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling,
ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling,
}
@staticmethod
......
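# [Editor's sketch] Worked example of the E8M0 decode above: each scale byte b
# decodes to 2**(b - 127), so 127 -> 1.0, 130 -> 8.0, and bytes below 127 give
# scales smaller than one (the "negative" exponents noted in the comment).
import jax.numpy as jnp

scale_bytes = jnp.array([120, 127, 130], jnp.float32)
print(jnp.power(2.0, scale_bytes - 127))  # [0.0078125, 1.0, 8.0]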
......@@ -27,7 +27,14 @@ from transformer_engine.jax.sharding import global_shard_guard, MeshResource
from .scaling_modes import ScalingMode
from .. import cpp_extensions as tex
__all__ = ["QuantizeConfig", "fp8_autocast", "is_fp8_available", "update_collections"]
__all__ = [
"QuantizeConfig",
"fp8_autocast",
"is_fp8_available",
"update_collections",
"get_delayed_scaling",
"NVTE_FP8_COLLECTION_NAME",
]
_is_fp8_available = None
_reason_for_no_fp8 = ""
......@@ -87,15 +94,15 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
A tuple of (bool, str) indicating support and any error message
"""
gpu_arch = get_device_compute_capability(gpu_id)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
return _check_delayed_scaling_fp8_support(gpu_arch)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _check_block_scaling_fp8_support(gpu_arch)
return (False, "Unsupported scaling_mode!")
def is_fp8_available(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
gpu_id=None,
) -> Tuple[bool, str]:
"""Check if FP8 is available for the given scaling mode and GPU.
......@@ -172,37 +179,12 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
ValueError: If the recipe type is not supported
"""
if isinstance(fp8_recipe, recipe.DelayedScaling):
return ScalingMode.NVTE_DELAYED_TENSOR_SCALING
return ScalingMode.DELAYED_TENSOR_SCALING
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return ScalingMode.NVTE_MXFP8_1D_SCALING
return ScalingMode.MXFP8_1D_SCALING
raise ValueError("Invalid fp8_recipe!")
def update_collections(new: Collection, original: Collection) -> Collection:
"""Update collections with new values while preserving original structure.
Args:
new: New collection of values to add/update
original: Original collection to update
Returns:
Updated collection with new values merged with original
Raises:
AssertionError: If either collection is not a dict or FrozenDict
"""
assert isinstance(original, (dict, FrozenDict))
assert isinstance(new, (dict, FrozenDict))
frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
for key in new:
if key in frozen_original:
frozen_original, _ = frozen_original.pop(key)
new_coll = FrozenDict({**new, **frozen_original})
if not isinstance(original, FrozenDict):
new_coll = new_coll.unfreeze()
return new_coll
class QuantizeConfig:
"""Configuration class for quantization settings.
......@@ -227,7 +209,7 @@ class QuantizeConfig:
INITIALIZED = False
MARGIN: float = 0.0
COLLECTION_NAME: str = "quantize_meta"
COLLECTION_NAME: str = "fp8_metas"
FP8_FORMAT: recipe.Format = recipe.Format.HYBRID
FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0]
BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1]
......@@ -235,7 +217,7 @@ class QuantizeConfig:
FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False
IF_QUANTIZE_2X: bool = False
SCALING_MODE: ScalingMode = ScalingMode.NVTE_NO_SCALING
SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING
# DelayedScaling
AMAX_HISTORY_LEN: int = 1024
......@@ -271,11 +253,11 @@ class QuantizeConfig:
cls.MARGIN = 0.0
cls.FP8_FORMAT = recipe.Format.HYBRID
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING
cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.FP8_2X_ACC_FPROP = False
cls.FP8_2X_ACC_DGRAD = False
cls.FP8_2X_ACC_WGRAD = False
cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING
cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.IF_QUANTIZE_2X = False
# DelayedScaling
cls.AMAX_HISTORY_LEN = 1024
......@@ -414,3 +396,56 @@ def fp8_autocast(
yield
finally:
Config.finalize()
def get_delayed_scaling():
r"""
Obtain an instance of DelayedScaling which is set via fp8_autocast.
.. note::
We only store :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`,
and :attr:`amax_compute_algo` via fp8_autocast. Other parameters in
recipe.DelayedScaling are returned with their default values.
Returns
-------
delayed_scaling : DelayedScaling
an instance of DelayedScaling which is set via fp8_autocast.
"""
amax_compute_algo = (
"max" if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent"
)
return recipe.DelayedScaling(
margin=int(QuantizeConfig.MARGIN),
fp8_format=QuantizeConfig.FP8_FORMAT,
amax_history_len=QuantizeConfig.AMAX_HISTORY_LEN,
amax_compute_algo=amax_compute_algo,
)
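# [Editor's sketch] Hypothetical usage: recover the active DelayedScaling recipe
# inside an fp8_autocast region; the import path and recipe arguments are
# assumptions for illustration.
from transformer_engine.common import recipe

with fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
    ds = get_delayed_scaling()
    print(ds.amax_history_len, ds.amax_compute_algo)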
def update_collections(new: Collection, original: Collection) -> Collection:
r"""Update collections with new values while preserving original structure.
Args:
new: New collection of values to add/update
original: Original collection to update
Returns:
Updated collection with new values merged with original
Raises:
AssertionError: If either collection is not a dict or FrozenDict
"""
assert isinstance(original, (dict, FrozenDict))
assert isinstance(new, (dict, FrozenDict))
frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
for key in new:
if key in frozen_original:
frozen_original, _ = frozen_original.pop(key)
new_coll = FrozenDict({**new, **frozen_original})
if not isinstance(original, FrozenDict):
new_coll = new_coll.unfreeze()
return new_coll
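# [Editor's sketch] Worked example of the merge semantics above: keys present in
# `new` override the originals, all remaining original entries are preserved,
# and the return type follows the type of `original`.
merged = update_collections({"a": 1}, {"a": 0, "b": 2})
assert merged == {"a": 1, "b": 2}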
NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME
......@@ -14,7 +14,7 @@ from typing import Union, Optional
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeAxis
from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode
from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
......@@ -24,7 +24,7 @@ from .helper import (
)
__all__ = [
"QuantizeAxis",
"QuantizeLayout",
"Quantizer",
"QuantizerSet",
"DelayedScaleQuantizer",
......@@ -45,12 +45,12 @@ class Quantizer(ABC):
Attributes:
q_dtype: The data type for quantized values
scaling_mode: The scaling mode to use for quantization
q_axis: The quantization axis (row-wise, column-wise, or both)
q_layout: The quantization layout (row-wise, column-wise, or both)
"""
q_dtype: jnp.dtype
scaling_mode: ScalingMode
q_axis: QuantizeAxis
q_layout: QuantizeLayout
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
......@@ -59,7 +59,7 @@ class Quantizer(ABC):
Tuple of (children, aux_data) for tree operations
"""
children = ()
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data)
@classmethod
......@@ -85,30 +85,31 @@ class Quantizer(ABC):
Returns:
True if using both row-wise and column-wise quantization
"""
return self.q_axis == QuantizeAxis.ROWWISE_COLWISE
return self.q_layout == QuantizeLayout.ROWWISE_COLWISE
@abstractmethod
def get_layout(self) -> str:
"""Get the data layout.
def get_data_layout(self) -> str:
"""Get the data data_layout.
Returns:
Data layout in string format
Data layout in string format
"""
@abstractmethod
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Core quantization function to be implemented by subclasses.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values, default is x.dtype
flatten_axis: The axis along which the tensor is flattened to 2D for quantization
Returns:
A ScaledTensor1x containing the quantized data
"""
def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None):
def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1):
"""Quantize a tensor using the internal _quantize_func().
Args:
......@@ -116,21 +117,26 @@ class Quantizer(ABC):
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The axis along which the tensor is flattened to 2D for quantization
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
if (is_rowwise and is_colwise) or self.is_2x2x():
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype)
colwise_tensor = self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype)
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
if is_colwise:
return self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype)
return self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
)
return self._quantize_func(x, dq_dtype=dq_dtype)
return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
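# [Editor's sketch] Illustration of flatten_axis, assuming a quantizer instance
# `q` built elsewhere: with flatten_axis=-2, a (hidden_in, act, intermediate)
# kernel is scaled as the 2D matrix (hidden_in, act * intermediate) while the
# quantized data keeps its 3D shape.
import jax.numpy as jnp

w = jnp.ones((512, 2, 1024), jnp.bfloat16)
# scaled = q.quantize(w, is_rowwise=True, flatten_axis=-2)  # hypothetical instance
# scaled.data.shape  # still (512, 2, 1024); scales follow the flattened 2D view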
def get_scale_shapes(self, data_shape, is_padded=True):
def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1):
"""Get shapes for scale tensors.
Args:
......@@ -140,7 +146,7 @@ class Quantizer(ABC):
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded)
return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis)
def get_scale_dtype(self):
"""Get the data type for scale tensors.
......@@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer):
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE)
q_layout: Quantization layout (default: ROWWISE_COLWISE)
scale: Current scaling factor
amax_history: History of maximum absolute values
"""
scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE
scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field(
......@@ -181,35 +187,37 @@ class DelayedScaleQuantizer(Quantizer):
Tuple of (children, aux_data) for tree operations
"""
children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data)
def get_layout(self) -> str:
"""Get the data layout string.
def get_data_layout(self) -> str:
"""Get the data data_layout string.
Returns:
Data layout in string format
Data layout in string format
Raises:
ValueError: If the quantization layout is invalid
"""
layout = "NT"
if self.q_axis == QuantizeAxis.ROWWISE_COLWISE:
return layout
if self.q_axis == QuantizeAxis.ROWWISE:
return layout[0]
if self.q_axis == QuantizeAxis.COLWISE:
return layout[1]
raise ValueError(f"Invalid q_axis: {self.q_axis}")
def _quantize_func(self, x: jnp.ndarray, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
data_layout = "NT"
if self.q_layout == QuantizeLayout.ROWWISE_COLWISE:
return data_layout
if self.q_layout == QuantizeLayout.ROWWISE:
return data_layout[0]
if self.q_layout == QuantizeLayout.COLWISE:
return data_layout[1]
raise ValueError(f"Invalid q_layout: {self.q_layout}")
def _quantize_func(
self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
) -> ScaledTensor1x:
"""Quantize function helper for delayed scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: Axis along which the tensor is flattened to 2D for quantization. Defaults to -1.
Returns:
A ScaledTensor1x containing the quantized data
"""
......@@ -232,9 +240,12 @@ class DelayedScaleQuantizer(Quantizer):
scale_inv=scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
def quantize(self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None):
def quantize(
self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None, flatten_axis=-1
):
"""Quantize a tensor using the internal _quantize_func().
Args:
......@@ -242,32 +253,40 @@ class DelayedScaleQuantizer(Quantizer):
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: Axis along which the tensor is flattened to 2D for quantization. Defaults to -1.
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
if flatten_axis < 0:
flatten_axis += x.ndim
assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"
is_rowwise = (
is_rowwise
if is_rowwise is not None
else (self.q_axis == QuantizeAxis.ROWWISE or self.is_2x2x())
else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
)
is_colwise = (
is_colwise
if is_colwise is not None
else (self.q_axis == QuantizeAxis.COLWISE or self.is_2x2x())
else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
)
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype)
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = None
if is_colwise:
colwise_tensor = ScaledTensorFactory.create_1x(
data=jnp.transpose(rowwise_tensor.data, (-1, *range(rowwise_tensor.data.ndim - 1))),
data=jnp.transpose(
rowwise_tensor.data, (*range(flatten_axis, x.ndim), *range(flatten_axis))
),
scale_inv=rowwise_tensor.scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
is_colwise=True,
layout="T",
data_layout="T",
flatten_axis=flatten_axis,
)
if is_colwise and is_rowwise:
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
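The column-wise tensor above is just the row-wise data transposed around flatten_axis; a small sketch of that permutation for a hypothetical rank-3 tensor:
import jax.numpy as jnp

x = jnp.zeros((4, 6, 8))
flatten_axis = 2  # rows = (4, 6), columns = (8,)
perm = (*range(flatten_axis, x.ndim), *range(flatten_axis))  # (2, 0, 1)
print(jnp.transpose(x, perm).shape)  # (8, 4, 6): the column block moves to the front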
......@@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer):
Attributes:
scaling_mode: Set to NVTE_MXFP8_1D_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE)
q_layout: Quantization layout (default: ROWWISE_COLWISE)
"""
scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE
scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
def get_layout(self) -> str:
"""Get the data layout string.
def get_data_layout(self) -> str:
"""Get the data data_layout string.
Returns:
Data layout in string format
Data layout in string format
"""
if self.is_2x2x():
return "NN"
return "N"
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Quantize function helper for block scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: Axis along which the tensor is flattened to 2D for quantization. Defaults to -1.
Returns:
A ScaledTensor1x containing the quantized data
"""
# TODO(Phuong): use quantize_func from JAX
if flatten_axis < 0:
flatten_axis = x.ndim + flatten_axis
assert (
0 <= flatten_axis < x.ndim
), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}"
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
x_shape = x.shape
scale_shape = self.scaling_mode.get_scale_shape(x_shape, is_colwise, is_padded=False)
scale_shape = self.scaling_mode.get_scale_shape(
x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
)
scale_dtype = self.scaling_mode.get_scale_dtype()
x = x.reshape(
*x_shape[:-2],
scale_shape[-2],
int(x_shape[-2] / scale_shape[-2]),
*x_shape[: flatten_axis - 1],
scale_shape[flatten_axis - 1],
int(x_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*x_shape[flatten_axis:-1],
scale_shape[-1],
int(x_shape[-1] / scale_shape[-1]),
)
amax = jnp.max(jnp.abs(x), axis=(-3, -1), keepdims=True)
amax = jnp.max(jnp.abs(x), axis=(flatten_axis, -1), keepdims=True)
MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
scales = amax.astype(jnp.float32) / MAX
......@@ -409,6 +438,7 @@ class BlockScaleQuantizer(Quantizer):
self.scaling_mode,
is_colwise=is_colwise,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
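A self-contained sketch of the block-amax step above, assuming a 2D tensor and the (1, 32) block dims of this mode (padding and the E8M0 scale encoding are omitted):
import jax.numpy as jnp

x = jnp.ones((64, 128), jnp.float32)
block = 32
xb = x.reshape(64, 128 // block, block)              # 32-wide blocks along the last dim
amax = jnp.max(jnp.abs(xb), axis=-1, keepdims=True)  # one amax per block
MAX = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)
scales = amax.astype(jnp.float32) / MAX              # shape (64, 4, 1)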
def _cast_to_e8m0_with_rounding_up(self, scales):
......@@ -500,8 +530,8 @@ class QuantizerFactory:
"""
quantizer_type_map = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScaleQuantizer,
ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer,
}
@staticmethod
......@@ -509,7 +539,7 @@ class QuantizerFactory:
n_quantizers: int = 1,
scaling_mode: ScalingMode = None,
q_dtype: jnp.dtype = None,
q_axis: QuantizeAxis = None,
q_layout: QuantizeLayout = None,
**kwargs,
) -> Quantizer:
"""Create one or more quantizers with specified parameters.
......@@ -518,15 +548,17 @@ class QuantizerFactory:
n_quantizers: Number of quantizers to create
scaling_mode: Scaling mode to use
q_dtype: Quantization data type
q_axis: Quantization axis
q_layout: Quantization layout
flatten_axis: Axis along which the tensor is flattened to 2D for quantization. Defaults to -1.
**kwargs: Additional arguments for quantizer initialization
Returns:
A single quantizer or tuple of quantizers
"""
# (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted
# assert scaling_mode != ScalingMode.NVTE_INVALID_SCALING
if scaling_mode in (ScalingMode.NVTE_NO_SCALING, ScalingMode.NVTE_INVALID_SCALING):
assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type"
if scaling_mode == ScalingMode.NO_SCALING:
quantizers = [None] * n_quantizers
else:
quantizers = []
......@@ -534,7 +566,7 @@ class QuantizerFactory:
quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)
quantizers.append(
quantizer_type(
q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis, **kwargs
q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs
)
)
return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)
......@@ -554,11 +586,11 @@ class QuantizerFactory:
A QuantizerSet instance
"""
if is_2x2x:
q_axis_x = q_axis_kernel = q_axis_dgrad = QuantizeAxis.ROWWISE_COLWISE
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
else:
q_axis_x = QuantizeAxis.ROWWISE
q_axis_kernel = QuantizeAxis.COLWISE
q_axis_dgrad = None
q_layout_x = QuantizeLayout.ROWWISE
q_layout_kernel = QuantizeLayout.COLWISE
q_layout_dgrad = None
if "quantize_meta_set" in kwargs:
quantize_meta_set = kwargs.get("quantize_meta_set")
......@@ -577,9 +609,11 @@ class QuantizerFactory:
else:
args_x = args_kernel = args_grad = {}
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_x, **args_x)
q_kernel = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_kernel, **args_kernel)
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_axis_dgrad, **args_grad)
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, **args_x)
q_kernel = QuantizerFactory.create(
1, scaling_mode, fwd_dtype, q_layout_kernel, **args_kernel
)
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_layout_dgrad, **args_grad)
return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)
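A usage sketch, assuming create_set accepts the (scaling_mode, fwd_dtype, bwd_dtype, is_2x2x) parameters that its body uses:
import jax.numpy as jnp

qset = QuantizerFactory.create_set(
    scaling_mode=ScalingMode.MXFP8_1D_SCALING,
    fwd_dtype=jnp.float8_e4m3fn,
    bwd_dtype=jnp.float8_e5m2,
    is_2x2x=True,
)
# qset.x, qset.kernel, and qset.dgrad each hold a BlockScaleQuantizer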
@staticmethod
......@@ -618,4 +652,4 @@ class QuantizerFactory:
return q_set[0] if len(q_set) == 1 else tuple(q_set)
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NVTE_NO_SCALING)
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING)
......@@ -16,11 +16,33 @@ from typing import Tuple, Dict
from functools import reduce
import operator
from jax.experimental.custom_partitioning import CompoundFactor
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp
from transformer_engine_jax import JAXX_Scaling_Mode
__all__ = ["ScalingMode"]
__all__ = ["QuantizeShardyRules", "ScalingMode"]
@dataclass
class QuantizeShardyRules:
"""Information necessary to shard scale tensors with Shardy.
Attributes:
input_spec: Specification for the input axes
rowwise_rule: Sharding rule for the row-wise scale tensor, depends on
the axes in `input_spec`
colwise_rule: Likewise for the column-wise scale tensor.
factor_sizes: For block scaling, contains the block size factor, which is
used in `input_spec`.
"""
input_spec: Tuple[str]
rowwise_rule: Tuple[str]
colwise_rule: Tuple[str]
factor_sizes: Dict[str, int]
class ScalingModeMetadataImpl(ABC):
......@@ -40,7 +62,11 @@ class ScalingModeMetadataImpl(ABC):
@abstractmethod
def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
self,
data_shape: Tuple[int, ...],
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
) -> Tuple[int, ...]:
"""Get the shape for scale tensors.
......@@ -48,11 +74,26 @@ class ScalingModeMetadataImpl(ABC):
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors
"""
@abstractmethod
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for delayed scaling mode.
......@@ -69,7 +110,11 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
return jnp.float32
def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
self,
data_shape: Tuple[int, ...],
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
) -> Tuple[int, ...]:
"""Get the shape for scale tensors in delayed scaling.
......@@ -77,6 +122,7 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
data_shape: The shape of the tensor being scaled
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors - (1,)
......@@ -84,6 +130,23 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
del data_shape, is_colwise
return (1,)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
del flatten_axis
input_spec = tuple(f"x{i}" for i in range(input_rank))
return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {})
class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for block scaling mode.
......@@ -113,8 +176,35 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""
return jnp.float8_e8m0fnu
def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim):
"""Remove excess padding from the scale shape and return the shape with respect to the original data shape."""
if len(data_shape) > 1:
# handle last dim
assert data_shape[-1] % scale_block_dim == 0
last = data_shape[-1] // scale_block_dim
scale_shape = (last,)
assert n_scale_blocks % last == 0
n_scale_blocks //= last
# handle middle dim, exclude first and last
for mid in reversed(data_shape[1:-1]):
scale_shape = (mid,) + scale_shape
assert n_scale_blocks % mid == 0
n_scale_blocks //= mid
scale_shape = (n_scale_blocks,) + scale_shape
else:
scale_shape = (n_scale_blocks,)
assert len(scale_shape) == len(
data_shape
), f"scale_shape {scale_shape}, data_shape {data_shape}"
return scale_shape
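A worked instance of this correction in pure Python, redistributing a flat block count back over the leading dims of a hypothetical data_shape:
data_shape = (2, 3, 64)   # dims on one side of flatten_axis
scale_block_dim = 1       # block_x of a (1, 32) block
n_scale_blocks = 384      # 2 * 3 * 64 blocks; padding happens to be a no-op here

last = data_shape[-1] // scale_block_dim       # 64
scale_shape = (last,)
n_scale_blocks //= last                        # 6 left to distribute
for mid in reversed(data_shape[1:-1]):         # middle dims: (3,)
    scale_shape = (mid,) + scale_shape         # (3, 64)
    n_scale_blocks //= mid                     # 2
scale_shape = (n_scale_blocks,) + scale_shape  # (2, 3, 64)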
def get_scale_shape(
self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
self,
data_shape: Tuple[int, ...],
is_colwise: bool = False,
is_padded: bool = True,
flatten_axis: int = -1,
) -> Tuple[int, ...]:
"""Get the shape for scale tensors in block scaling.
......@@ -122,6 +212,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors
......@@ -135,38 +226,87 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
block_x, block_y = self._block_dims
alignment_x, alignment_y = block_alignment
seq_axis = len(data_shape) - 2
if flatten_axis < 0:
flatten_axis = len(data_shape) + flatten_axis
assert (
data_shape[seq_axis] % block_x == 0
), f"Input data of shape {data_shape} should be padded by {block_x} in axes={seq_axis}"
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
assert data_shape[flatten_axis - 1] % block_x == 0, (
f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
f" {flatten_axis - 1}"
)
assert (
data_shape[-1] % block_y == 0
), f"Input data of shape {data_shape} should be padded by {block_y} in axis -1"
), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"
# NOTE: this overpads if dim > 2 and dims before seq_axis are greater than 1
n_block_seq = data_shape[seq_axis] // block_x
n_block_y = data_shape[-1] // block_y
flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)
n_flat_first_dim = reduce(operator.mul, data_shape[:seq_axis], 1) * n_block_seq
assert flattened_first_dim % block_x == 0, (
f"Flattened first dim - mutiplication of axes={tuple(range(0, flatten_axis))} of shape"
f" {data_shape} - should be divisible by block_x {block_x}"
)
assert flattened_last_dim % block_y == 0, (
"Flattened last dim - mutiplication of"
f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be"
f" divisible by block_y {block_y}"
)
# Padding
n_flat_first_dim = ((n_flat_first_dim + alignment_x - 1) // alignment_x) * alignment_x
n_block_y = ((n_block_y + alignment_y - 1) // alignment_y) * alignment_y
n_block_x = int(flattened_first_dim / block_x)
n_block_y = int(flattened_last_dim / block_y)
out_shape = ()
for i in range(seq_axis):
d = data_shape[i]
out_shape += (d,)
assert n_flat_first_dim % d == 0
n_flat_first_dim //= d
# padding
n_block_x = int(((n_block_x + alignment_x - 1) // alignment_x) * alignment_x)
n_block_y = int(((n_block_y + alignment_y - 1) // alignment_y) * alignment_y)
out_shape += (n_flat_first_dim, n_block_y)
first_dim_scale_shape = self._apply_scale_shape_correction(
data_shape[:flatten_axis], n_block_x, block_x
)
last_dim_scale_shape = self._apply_scale_shape_correction(
data_shape[flatten_axis:], n_block_y, block_y
)
return out_shape
return (*first_dim_scale_shape, *last_dim_scale_shape)
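For example, with the (1, 32) blocks of this mode, a rank-3 tensor, and flatten_axis=2 (unpadded shapes; the column-wise result assumes the block dims are swapped for the column-wise pass, as in the matching kernels):
mode = ScalingMode.MXFP8_1D_SCALING
rowwise = mode.get_scale_shape((2, 64, 128), is_colwise=False, is_padded=False, flatten_axis=2)
colwise = mode.get_scale_shape((2, 64, 128), is_colwise=True, is_padded=False, flatten_axis=2)
print(rowwise)  # (2, 64, 4): 128 columns in 32-wide blocks
print(colwise)  # (2, 2, 128): 2 * 64 = 128 flattened rows in 32-long blocks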
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
input_spec = [f"x{i}" for i in range(input_rank)]
# We have to use two different factors in the two CompoundFactors because of Shardy
# verifier requirements, even though they are the same.
rowwise_var = unique_var
colwise_var = f"{unique_var}_"
input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")
# The rowwise and colwise scale tensors should be sharded the same way as the input.
# However, we need to adjust the dimensions where the block scaling factor applies.
rowwise = input_spec.copy()
rowwise[-1] = rowwise_var
colwise = input_spec.copy()
colwise[flatten_axis - 1] = colwise_var
# This implementation needs to be updated for different block dims.
assert self._block_dims == (1, 32)
return QuantizeShardyRules(
tuple(input_spec),
tuple(rowwise),
tuple(colwise),
{"block_size_rowwise": 32, "block_size_colwise": 32},
)
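Tracing the construction above for a hypothetical rank-3 input with flatten_axis=2 and unique_var "u":
rules = ScalingMode.MXFP8_1D_SCALING.get_shardy_sharding_rules(3, "u", flatten_axis=2)
# rules.input_spec   == ("x0", CompoundFactor("u_", "block_size_colwise"),
#                        CompoundFactor("u", "block_size_rowwise"))
# rules.rowwise_rule == ("x0", CompoundFactor("u_", "block_size_colwise"), "u")
# rules.colwise_rule == ("x0", "u_", CompoundFactor("u", "block_size_rowwise"))
# rules.factor_sizes == {"block_size_rowwise": 32, "block_size_colwise": 32}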
@dataclass(frozen=True)
......@@ -175,16 +315,14 @@ class ScalingMode(Enum):
"""Enumeration of tensor scaling modes with their corresponding metadata implementations.
This class defines the available scaling modes for tensor quantization:
- NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- NVTE_INVALID_SCALING: Invalid scaling mode
- NVTE_NO_SCALING: No scaling applied
- DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- NO_SCALING: No scaling applied
"""
NVTE_DELAYED_TENSOR_SCALING = 0
NVTE_MXFP8_1D_SCALING = 1
NVTE_INVALID_SCALING = 2
NVTE_NO_SCALING = 3
NO_SCALING = JAXX_Scaling_Mode.NO_SCALING
DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING
MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING
def _get_impl(self) -> ScalingModeMetadataImpl:
"""Get the implementation for this scaling mode.
......@@ -208,34 +346,54 @@ class ScalingMode(Enum):
"""
return self._get_impl().get_scale_dtype()
def get_scale_shape_2x(self, data_shape, is_padded=True) -> Tuple[Tuple[int]]:
def get_scale_shape_2x(self, data_shape, is_padded=True, flatten_axis=-1) -> Tuple[Tuple[int]]:
"""Get shapes for both row-wise and column-wise scaling.
Args:
data_shape: Shape of the data tensor
is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
rowwise_scale_shape = self.get_scale_shape(
data_shape, is_colwise=False, is_padded=is_padded
data_shape, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis
)
colwise_scale_shape = self.get_scale_shape(
data_shape, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis
)
colwise_scale_shape = self.get_scale_shape(data_shape, is_colwise=True, is_padded=is_padded)
return (rowwise_scale_shape, colwise_scale_shape)
def get_scale_shape(self, data_shape, is_colwise, is_padded=True) -> Tuple[int]:
def get_scale_shape(
self, data_shape, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
"""Get the shape for scale tensors in this mode.
Args:
data_shape: Shape of the data tensor
is_colwise: Whether to use column-wise scaling
is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors
"""
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded)
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis=-1
) -> QuantizeShardyRules:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The Shardy rules for the scaling mode
"""
return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)
def __eq__(self, other):
"""Compare this scaling mode with another.
......@@ -273,8 +431,8 @@ class ScalingMode(Enum):
SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
# WAR
ScalingMode.NVTE_NO_SCALING: DelayedScalingModeMetadataImpl(),
ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(),
}
......@@ -15,7 +15,7 @@ from abc import ABC, abstractmethod
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeAxis
from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode
from .dequantizer import Dequantizer
......@@ -84,6 +84,17 @@ class ScaledTensor(ABC):
ValueError: If called on a tensor that doesn't support column-wise access
"""
@abstractmethod
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
@register_pytree_node_class
@dataclass
......@@ -100,7 +111,8 @@ class ScaledTensor1x(ScaledTensor):
dq_dtype: The data type for dequantized values
_dq_func: The dequantization function
is_colwise: Whether the tensor uses column-wise quantization
layout: The layout specification for the tensor
data_layout: The data layout specification for the tensor
flatten_axis: Axis along which the tensor is flattened to 2D for quantization. Defaults to -1.
"""
data: jnp.ndarray
......@@ -109,7 +121,8 @@ class ScaledTensor1x(ScaledTensor):
dq_dtype: jnp.dtype
_dq_func: Callable
is_colwise: bool
layout: str
data_layout: str
flatten_axis: int = -1
def __post_init__(self):
"""Validates and adjusts the scale_inv shape after initialization.
......@@ -117,11 +130,22 @@ class ScaledTensor1x(ScaledTensor):
Ensures the scale_inv shape matches the expected shape based on the scaling mode
and quantization direction. Pads the scale_inv if necessary.
"""
flatten_axis = (
len(self.data.shape) + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis
)
assert (
0 < flatten_axis < len(self.data.shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {self.data.shape}"
if self.data_layout == "T":
flatten_axis = self.data.ndim - flatten_axis
self.flatten_axis = flatten_axis
expected_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=True
self.data.shape, self.is_colwise, is_padded=True, flatten_axis=flatten_axis
)
expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=False
self.data.shape, self.is_colwise, is_padded=False, flatten_axis=flatten_axis
)
if self.scale_inv.shape != expected_scale_shape:
assert self.scale_inv.shape == expected_unpadded_scale_shape, (
......@@ -144,7 +168,14 @@ class ScaledTensor1x(ScaledTensor):
A tuple containing (children, aux_data) for tree operations
"""
children = (self.data, self.scale_inv)
aux_data = (self.scaling_mode, self.dq_dtype, self._dq_func, self.is_colwise, self.layout)
aux_data = (
self.scaling_mode,
self.dq_dtype,
self._dq_func,
self.is_colwise,
self.data_layout,
self.flatten_axis,
)
return (children, aux_data)
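Since flatten_axis now travels in aux_data, a pytree roundtrip preserves it; a quick sketch, where t stands for any existing ScaledTensor1x:
import jax

leaves, treedef = jax.tree_util.tree_flatten(t)   # leaves == (data, scale_inv)
t2 = jax.tree_util.tree_unflatten(treedef, leaves)
assert t2.flatten_axis == t.flatten_axis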
def dequantize(self):
......@@ -183,6 +214,45 @@ class ScaledTensor1x(ScaledTensor):
raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!")
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
if not logical_axis_names:
return self
# axis_names were given for the N layout, so they must be transposed for the T layout
if self.data_layout == "T":
assert self.flatten_axis > 0
flatten_axis = -self.flatten_axis
axis_names = (*logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis])
else:
axis_names = logical_axis_names
data = with_sharding_constraint_by_logical_axes(self.data, axis_names)
if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
# TODO(Phuong): Handle padding !?
scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
else:
scale_inv = self.scale_inv
return ScaledTensor1x(
data=data,
scale_inv=scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=self.dq_dtype,
_dq_func=self._dq_func,
is_colwise=self.is_colwise,
data_layout=self.data_layout,
flatten_axis=self.flatten_axis,
)
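The name rotation above mirrors the data permutation; e.g. for a hypothetical rank-3 tensor quantized with flatten_axis=2 (stored as ndim - 2 = 1 after the __post_init__ flip for "T" layout):
logical_axis_names = ("batch", "seq", "hidden")  # given for the "N" layout
flatten_axis = -1                                # -self.flatten_axis with self.flatten_axis == 1
axis_names = (*logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis])
print(axis_names)  # ("hidden", "batch", "seq"): matches "T" data of shape (H, B, S)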
@register_pytree_node_class
@dataclass
......@@ -233,6 +303,27 @@ class ScaledTensor2x(ScaledTensor):
"""
return self.colwise_tensor
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
if not logical_axis_names:
return self
rowwise_tensor = self.rowwise_tensor.apply_sharding_constraint_by_logical_axes(
logical_axis_names
)
colwise_tensor = self.colwise_tensor.apply_sharding_constraint_by_logical_axes(
logical_axis_names
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
@dataclass
class ScaledTensorFactory:
......@@ -244,7 +335,13 @@ class ScaledTensorFactory:
@staticmethod
def create_1x(
data, scale_inv, scaling_mode, dq_dtype=jnp.bfloat16, is_colwise=False, layout="N"
data,
scale_inv,
scaling_mode,
dq_dtype=jnp.bfloat16,
is_colwise=False,
data_layout="N",
flatten_axis=-1,
):
"""Creates a single-scale quantized tensor.
......@@ -254,13 +351,16 @@ class ScaledTensorFactory:
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
is_colwise: Whether to use column-wise quantization (default: False)
layout: The layout specification (default: "N")
data_layout: The data layout specification (default: "N")
flatten_axis: Axis along which the tensor is flattened to 2D for quantization. Defaults to -1.
Returns:
A ScaledTensor1x instance
"""
dq_func = Dequantizer.funcs.get(scaling_mode)
return ScaledTensor1x(data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, layout)
return ScaledTensor1x(
data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, data_layout, flatten_axis
)
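A usage sketch of the factory with delayed scaling, which keeps a single scalar inverse scale per tensor; shapes are illustrative:
import jax.numpy as jnp

data = jnp.zeros((16, 32), jnp.float8_e4m3fn)
scale_inv = jnp.ones((1,), jnp.float32)  # delayed scaling: scale shape is (1,)
t = ScaledTensorFactory.create_1x(
    data, scale_inv, ScalingMode.DELAYED_TENSOR_SCALING,
    dq_dtype=jnp.bfloat16, is_colwise=False, data_layout="N",
)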
@staticmethod
def create_2x(
......@@ -270,7 +370,8 @@ class ScaledTensorFactory:
colwise_scale_inv,
scaling_mode,
dq_dtype=jnp.bfloat16,
layout="NN",
data_layout="NN",
flatten_axis=-1,
):
"""Creates a double-scale quantized tensor.
......@@ -281,7 +382,8 @@ class ScaledTensorFactory:
colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN")
data_layout: The data layout specification (default: "NN")
flatten_axis: Axis along which the tensor is flattened to 2D for quantization. Defaults to -1.
Returns:
A ScaledTensor2x instance
......@@ -294,7 +396,8 @@ class ScaledTensorFactory:
dq_dtype,
dq_func,
is_colwise=False,
layout=layout[0],
data_layout=data_layout[0],
flatten_axis=flatten_axis,
)
colwise_tensor = ScaledTensor1x(
colwise_data,
......@@ -303,7 +406,8 @@ class ScaledTensorFactory:
dq_dtype,
dq_func,
is_colwise=True,
layout=layout[1],
data_layout=data_layout[1],
flatten_axis=flatten_axis,
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
......@@ -315,8 +419,9 @@ class ScaledTensorFactory:
colwise_scale_inv: jnp.ndarray,
scaling_mode: ScalingMode,
dq_dtype: jnp.dtype = jnp.bfloat16,
layout: str = "NN",
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE,
data_layout: str = "NN",
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
flatten_axis: int = -1,
):
"""Creates a scaled tensor based on the quantization axis.
......@@ -327,13 +432,13 @@ class ScaledTensorFactory:
colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN")
q_axis: The quantization axis (default: ROWWISE)
data_layout: The data layout specification (default: "NN")
q_layout: The quantization layout (default: ROWWISE)
Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_axis
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
"""
if q_axis == QuantizeAxis.ROWWISE_COLWISE:
if q_layout == QuantizeLayout.ROWWISE_COLWISE:
return ScaledTensorFactory.create_2x(
data,
scale_inv,
......@@ -341,12 +446,19 @@ class ScaledTensorFactory:
colwise_scale_inv,
scaling_mode,
dq_dtype,
layout=layout,
data_layout=data_layout,
flatten_axis=flatten_axis,
)
is_colwise = q_axis == QuantizeAxis.COLWISE
is_colwise = q_layout == QuantizeLayout.COLWISE
return ScaledTensorFactory.create_1x(
data, scale_inv, scaling_mode, dq_dtype, is_colwise=is_colwise, layout=layout[0]
data,
scale_inv,
scaling_mode,
dq_dtype,
is_colwise=is_colwise,
data_layout=data_layout[0],
flatten_axis=flatten_axis,
)
......@@ -360,24 +472,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, .
Returns:
The tensor with applied sharding constraints
"""
if isinstance(x, ScaledTensor1x):
return ScaledTensor1x(
data=with_sharding_constraint_by_logical_axes(x.data, logical_axis_names),
scale_inv=x.scale_inv,
scaling_mode=x.scaling_mode,
dq_dtype=x.dq_dtype,
_dq_func=x._dq_func,
is_colwise=x.is_colwise,
layout=x.layout,
)
if isinstance(x, ScaledTensor2x):
return ScaledTensor2x(
rowwise_tensor=with_sharding_constraint_by_logical_axes(
x.rowwise_tensor, logical_axis_names
),
colwise_tensor=with_sharding_constraint_by_logical_axes(
x.colwise_tensor, logical_axis_names
),
)
if isinstance(x, ScaledTensor):
return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)
return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)