Unverified commit 99df8810 authored by Tim Moon, committed by GitHub

Add logic for block-scaled tensors with GEMM swizzled scales (#2486)



* Add general C API for setting tensor params
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Implement general accessors for NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor tex swizzling to skip if scales are already swizzled
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add checks for non-swizzled scales in MXFP8 and NVFP4 kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support pre-swizzled scales in MXFP8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tex function to swizzle MXFP8 scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in inplace swizzle function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak comments to use "compact/swizzled format"
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* MXFP8 quantize kernel with pre-swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expose pre-swizzled scales in modules
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in multi-swizzle
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support MXFP8 gated activations with swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add PyTorch infrastructure for pre-swizzled NVFP4 tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Deprecate DSv3-specific quantization logic in C API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove support for DSv3 compact data from quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove DSv3 compact data format from core lib
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in FP8 all-gather
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update JAX to use new swizzled scale API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update C++ swizzle test with swizzled scales API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Return default tensor params when querying params for invalid NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug DSv3 FP8 test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Userbuffers test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure gated activations populate FP8 transpose if needed
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable pre-swizzling with debug quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflicts and review suggestions

Update copyright years. Tweak comments. Fix various complaints from @greptile-apps.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use explicitly sized types in config accessors

Miscellaneous review suggestions from @ptrendx.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Make util header for functions that compute the swizzled scale index
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Apply suggestions from @greptile-apps
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Update expected error message in FP8 block-scaling test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @yaox12
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent a652730f
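
The changes below thread a single piece of state through the stack: quantizers
gain an optimize_for_gemm flag, and quantized tensors record whether their
block scales are already laid out in the GEMM swizzled format. A minimal
sketch of the C++ side, using only the accessors added in this commit (the
surrounding setup is hypothetical):

    // Request scales directly in the GEMM swizzled layout.
    quantizer_cpp->optimize_for_gemm = true;
    auto [out_nvte, out_py] = quantizer_cpp->create_tensor(shape, dtype);

    // GEMM wrappers consult the recorded layout and skip redundant swizzles.
    if (!out_nvte.get_with_gemm_swizzled_scales()) {
      // fall back to an explicit nvte_swizzle_scaling_factors pass
    }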
......@@ -120,6 +120,7 @@ class Quantizer {
bool rowwise_usage = true;
bool columnwise_usage = true;
bool internal = false;
bool optimize_for_gemm = false;
py::handle quantizer;
protected:
......@@ -231,8 +232,6 @@ class Float8BlockQuantizer : public Quantizer {
bool force_pow_2_scales = false;
// Amax within quantization tile has a floor of epsilon.
float amax_epsilon = 0.0;
// Whether quantized tensor will be used in an all-gather
bool all_gather_usage = false;
private:
int block_scaling_dim = 2;
......@@ -358,11 +357,12 @@ inline size_t typeToNumBits(transformer_engine::DType t) {
case transformer_engine::DType::kByte:
case transformer_engine::DType::kFloat8E4M3:
case transformer_engine::DType::kFloat8E5M2:
case transformer_engine::DType::kFloat8E8M0:
return 8;
case transformer_engine::DType::kFloat4E2M1:
return 4;
default:
NVTE_ERROR("Invalid type");
NVTE_ERROR("Invalid type (", static_cast<int>(t), ").");
}
}
......@@ -386,8 +386,10 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) {
return at::kFloat8_e4m3fn;
case transformer_engine::DType::kFloat8E5M2:
return at::kFloat8_e5m2;
case transformer_engine::DType::kFloat8E8M0:
return at::kByte; // e8m0 dtype requires PyTorch 2.7.0+
default:
NVTE_ERROR("Invalid type");
NVTE_ERROR("Invalid type (", static_cast<int>(t), ").");
}
}
......@@ -414,8 +416,7 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
case torch::kInt64:
return transformer_engine::DType::kInt64;
default:
std::cout << "Type: " << static_cast<int>(t) << std::endl;
NVTE_ERROR("Invalid type");
NVTE_ERROR("Invalid type (", static_cast<int>(t), ").");
}
}
......@@ -477,7 +478,9 @@ void* getDataPtr(at::Tensor tensor, int offset = 0);
std::vector<size_t> convertShape(const NVTEShape& shape);
size_t roundup(const size_t value, const size_t multiple);
size_t roundup(size_t value, size_t multiple);
size_t ceildiv(size_t numer, size_t denom);
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
......
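
The new ceildiv helper complements roundup. Assuming the usual integer
definitions (only the declarations appear in this hunk), a quick sketch of
their behavior:

    size_t blocks = ceildiv(1000, 128);  // 8: number of 128-wide blocks covering 1000
    size_t padded = roundup(1000, 128);  // 1024: next multiple of 128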
......@@ -7,7 +7,12 @@
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#include <map>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "common.h"
......@@ -78,11 +83,6 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right, bool return_max_logit, bool cuda_graph);
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
const std::vector<size_t> &shape, DType dtype,
bool create_hp_tensor_for_cs,
std::optional<at::Tensor> data);
std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......@@ -475,6 +475,13 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
void fused_multi_row_unpadding(at::Tensor input, at::Tensor output,
std::vector<size_t> input_row_list,
std::vector<size_t> unpadded_input_row_list);
/***************************************************************************************************
* Scale swizzling for GEMM
**************************************************************************************************/
void inplace_swizzle_scale_for_gemm(py::handle &tensor);
/***************************************************************************************************
* NVSHMEM APIs
**************************************************************************************************/
......
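
The declaration above is the in-place entry point behind the
swizzle_scales_for_gemm_ Python binding registered later in this diff. A
minimal C++-side sketch of a caller (the pre_swizzle_for_gemm wrapper and its
argument are hypothetical; py is the usual pybind11 namespace alias):

    void pre_swizzle_for_gemm(py::handle &tensor_py) {
      // No-op for tensor formats that need no swizzling and for tensors
      // whose scales are already in the GEMM swizzled layout.
      transformer_engine::pytorch::inplace_swizzle_scale_for_gemm(tensor_py);
    }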
......@@ -327,9 +327,9 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
(columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none());
// Construct Python tensor
tensor_py_list.emplace_back(Float8BlockwiseQTensorClass(
rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype,
quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY));
tensor_py_list.emplace_back(
Float8BlockwiseQTensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale,
fp8_dtype, quantizer_py_list[i], is_2D_scaled));
// Construct C++ tensor
tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
......@@ -365,6 +365,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage;
const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
const auto fp8_dtype = quantizer_cpp_list[0]->dtype;
const bool with_gemm_swizzled_scales = quantizer_cpp_list[0]->optimize_for_gemm;
constexpr size_t fp8_elem_size = 1;
constexpr size_t scale_elem_size = 1;
......@@ -475,8 +477,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
// Construct Python tensor
tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data,
columnwise_scale, fp8_dtype,
quantizer_py_list[i]));
columnwise_scale, fp8_dtype, quantizer_py_list[i],
with_gemm_swizzled_scales));
// Construct C++ tensor
tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
......@@ -488,6 +490,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{0}, scaling_mode));
tensor_cpp_list.back().set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
}
return retval;
......@@ -517,6 +520,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage;
const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
const auto fp4_dtype = quantizer_cpp_list[0]->dtype;
const bool with_gemm_swizzled_scales = false;  // TODO (tmoon) Enable based on optimize_for_gemm
constexpr size_t scale_elem_size = 1;
// Helper function to construct tensor view
......@@ -675,9 +679,9 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none();
// Construct Python tensor
tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data,
columnwise_scale, amax_rowwise, amax_columnwise,
fp4_dtype, quantizer_py_list[i]));
tensor_py_list.emplace_back(NVFP4TensorClass(
rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, amax_rowwise,
amax_columnwise, fp4_dtype, quantizer_py_list[i], with_gemm_swizzled_scales));
// Construct C++ tensor
// Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor,
......@@ -693,6 +697,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{0}, scaling_mode);
tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
// Set the amax rowwise and amax columnwise if available
if (rowwise_usage) {
......@@ -703,6 +708,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
tensor_cpp_list.emplace_back(std::move(tensor_wrapper));
}
}
......
......@@ -240,9 +240,12 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
auto main_stream = at::cuda::getCurrentCUDAStream();
if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
swizzled_scale_inverses_list.emplace_back(
std::move(swizzle_scaling_factors(B_tensor, !transb)));
auto [A_row_scales, A_col_scales] = swizzle_scales_for_gemm(A_tensor, transa, !transa);
auto [B_row_scales, B_col_scales] = swizzle_scales_for_gemm(B_tensor, !transb, transb);
swizzled_scale_inverses_list.emplace_back(std::move(A_row_scales));
swizzled_scale_inverses_list.emplace_back(std::move(A_col_scales));
swizzled_scale_inverses_list.emplace_back(std::move(B_row_scales));
swizzled_scale_inverses_list.emplace_back(std::move(B_col_scales));
// Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer
// as it is not natively supported by cublasLt
......@@ -501,9 +504,9 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(
multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa));
multi_tensor_swizzle_scales_for_gemm(te_A_wrappers, transa, !transa));
swizzled_scale_inverses_list.emplace_back(
multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb));
multi_tensor_swizzle_scales_for_gemm(te_B_wrappers, !transb, transb));
// Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer
// as it is not natively supported by cublasLt
......
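
A note on the orientation flags above: each GEMM operand swizzles exactly one
of its two scale copies per usage, selected by its transpose mode. A condensed
restatement of the call pattern (local bindings are hypothetical):

    // A swizzles its row-wise scales iff transa, column-wise otherwise;
    // B is mirrored. The returned buffers own the swizzled storage and are
    // kept on a list so they outlive the GEMM kernel launch.
    auto [A_row, A_col] = swizzle_scales_for_gemm(A_tensor, transa, !transa);
    auto [B_row, B_col] = swizzle_scales_for_gemm(B_tensor, !transb, transb);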
......@@ -89,14 +89,8 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
TensorWrapper mu_nvte = makeTransformerEngineTensor(mu_py);
TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);
// Output tensor
// Quantizer
auto quantizer_cpp = convert_quantizer(quantizer);
TensorWrapper out_nvte;
if (out.is_none()) {
std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
} else {
out_nvte = makeTransformerEngineTensor(out, quantizer);
}
// Choose implementation
enum class Impl {
......@@ -135,6 +129,19 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
}
}
// Output tensor
TensorWrapper out_nvte;
if (out.is_none()) {
if (impl == Impl::FULLY_FUSED) {
// FP8 has no special logic to optimize for GEMM, and the fused MXFP8
// cuDNN kernel does not support GEMM swizzled scales
quantizer_cpp->optimize_for_gemm = false;
}
std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
} else {
out_nvte = makeTransformerEngineTensor(out, quantizer);
}
// Construct unquantized output tensor if needed
TensorWrapper unquantized_out_nvte;
py::object unquantized_out;
......@@ -318,14 +325,8 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
at::Tensor rsigma_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);
// Output tensor
// Quantizer
auto quantizer_cpp = convert_quantizer(quantizer);
TensorWrapper out_nvte;
if (out.is_none()) {
std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
} else {
out_nvte = makeTransformerEngineTensor(out, quantizer);
}
// Choose implementation
enum class Impl {
......@@ -364,6 +365,19 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
}
}
// Output tensor
TensorWrapper out_nvte;
if (out.is_none()) {
if (impl == Impl::FULLY_FUSED) {
// FP8 has no special logic to optimize for GEMM, and the fused MXFP8
// cuDNN kernel does not support GEMM swizzled scales
quantizer_cpp->optimize_for_gemm = false;
}
std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
} else {
out_nvte = makeTransformerEngineTensor(out, quantizer);
}
// Construct unquantized output tensor if needed
TensorWrapper unquantized_out_nvte;
py::object unquantized_out;
......
......@@ -290,6 +290,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding,
"Fused Multi-tensor unpadding", py::call_guard<py::gil_scoped_release>());
m.def("swizzle_scales_for_gemm_", &transformer_engine::pytorch::inplace_swizzle_scale_for_gemm,
"Convert tensor block scales into GEMM swizzled format");
// attention kernels
m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
#include "common.h"
#include "common/common.h"
#include "extensions.h"
#include "pybind.h"
#include "util.h"
namespace transformer_engine {
namespace pytorch {
namespace {
void reset_tensor_data(transformer_engine::TensorWrapper &tensor, bool rowwise, bool columnwise) {
NVTEShape shape;
shape.ndim = 1;
shape.data[0] = 0;
const transformer_engine::DType dtype = transformer_engine::DType::kFloat32;
if (rowwise) {
tensor.set_rowwise_data(nullptr, dtype, shape);
tensor.set_rowwise_scale_inv(nullptr, dtype, shape);
}
if (columnwise) {
tensor.set_columnwise_data(nullptr, dtype, shape);
tensor.set_columnwise_scale_inv(nullptr, dtype, shape);
}
}
} // namespace
std::tuple<std::optional<at::Tensor>, std::optional<at::Tensor>> swizzle_scales_for_gemm(
transformer_engine::TensorWrapper &tensor, bool rowwise_usage, bool columnwise_usage) {
// Return early if scale swizzling is not required
const auto scaling_mode = tensor.scaling_mode();
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING:
// Tensor format requires scale swizzling
break;
case NVTE_INVALID_SCALING:
NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
default:
// Tensor format does not require scale swizzling for GEMM
return {std::nullopt, std::nullopt};
}
// Return early if scales are already swizzled
if (tensor.get_with_gemm_swizzled_scales()) {
return {std::nullopt, std::nullopt};
}
// CUDA stream
auto stream = at::cuda::getCurrentCUDAStream();
// Swizzle row-wise scales if needed
std::optional<at::Tensor> rowwise_scales_pyt;
if (rowwise_usage) {
// Buffer for unswizzled scales
const auto input_scales_nvte = tensor.get_rowwise_scale_inv();
void *input_scales_dptr = input_scales_nvte.data_ptr;
const NVTEShape input_scales_shape = input_scales_nvte.shape;
const auto scales_dtype = static_cast<DType>(input_scales_nvte.dtype);
// Allocate buffer for swizzled scales
const NVTEShape output_scales_shape = input_scales_shape;
rowwise_scales_pyt = allocateSpace(input_scales_shape, scales_dtype, false);
void *output_scales_dptr = getDataPtr(*rowwise_scales_pyt);
// Initialize TE tensors with scales
const auto data_nvte = tensor.get_rowwise_data();
const auto data_dtype = static_cast<DType>(data_nvte.dtype);
TensorWrapper input_nvte(scaling_mode);
input_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape);
input_nvte.set_rowwise_scale_inv(input_scales_dptr, scales_dtype, input_scales_shape);
TensorWrapper output_nvte(scaling_mode);
output_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape);
output_nvte.set_rowwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape);
output_nvte.set_with_gemm_swizzled_scales(true);
// Launch kernel
NVTE_SCOPED_GIL_RELEASE(
{ nvte_swizzle_scaling_factors(input_nvte.data(), output_nvte.data(), stream); });
// Update tensor with swizzled scales
tensor.set_rowwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape);
}
// Swizzle column-wise scales if needed
std::optional<at::Tensor> columnwise_scales_pyt;
if (columnwise_usage) {
// Buffer for unswizzled scales
const auto input_scales_nvte = tensor.get_columnwise_scale_inv();
void *input_scales_dptr = input_scales_nvte.data_ptr;
const NVTEShape input_scales_shape = input_scales_nvte.shape;
const auto scales_dtype = static_cast<DType>(input_scales_nvte.dtype);
// Allocate buffer for swizzled scales
const NVTEShape output_scales_shape = input_scales_shape;
columnwise_scales_pyt = allocateSpace(input_scales_shape, scales_dtype, false);
void *output_scales_dptr = getDataPtr(*columnwise_scales_pyt);
// Initialize TE tensors with scales
const auto data_nvte = tensor.get_columnwise_data();
const auto data_dtype = static_cast<DType>(data_nvte.dtype);
TensorWrapper input_nvte(scaling_mode);
input_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape);
input_nvte.set_columnwise_scale_inv(input_scales_dptr, scales_dtype, input_scales_shape);
TensorWrapper output_nvte(scaling_mode);
output_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape);
output_nvte.set_columnwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape);
output_nvte.set_with_gemm_swizzled_scales(true);
// Launch kernel
NVTE_SCOPED_GIL_RELEASE(
{ nvte_swizzle_scaling_factors(input_nvte.data(), output_nvte.data(), stream); });
// Update tensor with swizzled scales
tensor.set_columnwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape);
}
// Update tensor
reset_tensor_data(tensor, !rowwise_usage, !columnwise_usage);
tensor.set_with_gemm_swizzled_scales(true);
return {std::move(rowwise_scales_pyt), std::move(columnwise_scales_pyt)};
}
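
namespace {

// Usage sketch (illustrative only, never called): swizzle a tensor's
// row-wise scales ahead of a GEMM. The returned buffer owns the swizzled
// storage and must stay alive until the GEMM that consumes it has launched.
[[maybe_unused]] std::optional<at::Tensor> example_swizzle_rowwise_scales(
    TensorWrapper &tensor) {
  auto [rowwise_buf, columnwise_buf] =
      swizzle_scales_for_gemm(tensor, /*rowwise_usage=*/true, /*columnwise_usage=*/false);
  return rowwise_buf;
}

}  // namespace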
std::optional<at::Tensor> multi_tensor_swizzle_scales_for_gemm(
std::vector<transformer_engine::TensorWrapper> &tensors, bool rowwise_usage,
bool columnwise_usage) {
// Checks and trivial cases
NVTE_CHECK(rowwise_usage != columnwise_usage,
"Expect exactly one of rowwise_usage=", rowwise_usage,
" and columnwise_usage=", columnwise_usage, ".");
if (tensors.empty()) {
return std::nullopt;
}
const auto scaling_mode = tensors.front().scaling_mode();
for (const auto &tensor : tensors) {
NVTE_CHECK(tensor.scaling_mode() == scaling_mode, "Tensors have different scaling modes");
}
// Return early if scale swizzling is not required
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING:
// Tensor format requires scale swizzling
break;
case NVTE_INVALID_SCALING:
NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
default:
// Tensor format does not require scale swizzling for GEMM
return std::nullopt;
}
// Filter out tensors that already have swizzled scales
std::vector<TensorWrapper *> tensors_needing_swizzle;
for (auto &tensor : tensors) {
if (!tensor.get_with_gemm_swizzled_scales()) {
tensors_needing_swizzle.push_back(&tensor);
}
}
if (tensors_needing_swizzle.empty()) {
return std::nullopt;
}
// Determine buffer size needed for swizzled scales
std::vector<size_t> output_scales_offsets;
size_t output_scales_bytes = 0;
for (auto &tensor : tensors_needing_swizzle) {
const auto scales_nvte =
(rowwise_usage ? tensor->get_rowwise_scale_inv() : tensor->get_columnwise_scale_inv());
const auto &shape = scales_nvte.shape;
const auto dtype = static_cast<DType>(scales_nvte.dtype);
const auto dtype_bits = transformer_engine::pytorch::typeToNumBits(dtype);
const auto size = product(shape, 0, shape.ndim);
output_scales_bytes = roundup(output_scales_bytes, 16); // align to 16B
output_scales_offsets.push_back(output_scales_bytes);
output_scales_bytes += ceildiv(size * dtype_bits, 8);
}
// Allocate buffer for swizzled scales
auto output_scales_pyt = allocateSpace(std::vector<size_t>{output_scales_bytes},
transformer_engine::DType::kByte, false);
uint8_t *output_scales_dptr = reinterpret_cast<uint8_t *>(getDataPtr(output_scales_pyt));
// Construct TE tensors with only scales
std::vector<transformer_engine::TensorWrapper> inputs_nvte, outputs_nvte;
for (size_t i = 0; i < tensors_needing_swizzle.size(); ++i) {
auto &tensor = *tensors_needing_swizzle[i];
inputs_nvte.emplace_back(scaling_mode);
outputs_nvte.emplace_back(scaling_mode);
auto &input_nvte = inputs_nvte.back();
auto &output_nvte = outputs_nvte.back();
output_nvte.set_with_gemm_swizzled_scales(true);
if (rowwise_usage) {
const auto data_nvte = tensor.get_rowwise_data();
const auto scales_nvte = tensor.get_rowwise_scale_inv();
const auto data_dtype = static_cast<transformer_engine::DType>(data_nvte.dtype);
const auto scales_dtype = static_cast<transformer_engine::DType>(scales_nvte.dtype);
input_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape);
input_nvte.set_rowwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape);
output_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape);
output_nvte.set_rowwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype,
scales_nvte.shape);
} else {
const auto data_nvte = tensor.get_columnwise_data();
const auto scales_nvte = tensor.get_columnwise_scale_inv();
const auto data_dtype = static_cast<transformer_engine::DType>(data_nvte.dtype);
const auto scales_dtype = static_cast<transformer_engine::DType>(scales_nvte.dtype);
input_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape);
input_nvte.set_columnwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape);
output_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape);
output_nvte.set_columnwise_scale_inv(output_scales_dptr + output_scales_offsets[i],
scales_dtype, scales_nvte.shape);
}
}
// Pack raw NVTETensors into vectors
std::vector<NVTETensor> inputs_nvte_raw, outputs_nvte_raw;
for (auto &tensor : inputs_nvte) {
inputs_nvte_raw.emplace_back(tensor.data());
}
for (auto &tensor : outputs_nvte) {
outputs_nvte_raw.emplace_back(tensor.data());
}
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte_raw.data(), outputs_nvte_raw.data(),
inputs_nvte_raw.size(),
at::cuda::getCurrentCUDAStream());
});
// Update tensors with swizzled scales
for (size_t i = 0; i < tensors_needing_swizzle.size(); ++i) {
auto &tensor = *tensors_needing_swizzle[i];
reset_tensor_data(tensor, !rowwise_usage, !columnwise_usage);
tensor.set_with_gemm_swizzled_scales(true);
if (rowwise_usage) {
auto scales_nvte = outputs_nvte[i].get_rowwise_scale_inv();
const auto scales_dtype = static_cast<transformer_engine::DType>(scales_nvte.dtype);
tensor.set_rowwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype,
scales_nvte.shape);
} else {
auto scales_nvte = outputs_nvte[i].get_columnwise_scale_inv();
const auto scales_dtype = static_cast<transformer_engine::DType>(scales_nvte.dtype);
tensor.set_columnwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype,
scales_nvte.shape);
}
}
return std::move(output_scales_pyt);
}
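
// Worked example of the packing above (illustrative numbers): two MXFP8
// tensors with E8M0 scale shapes {128, 4} and {37, 4}. E8M0 scales are
// 8 bits each, so the sizes are 512 B and 148 B, and every offset is first
// rounded up to a 16 B boundary:
//   tensor 0: offset   0, 512 B
//   tensor 1: offset 512, 148 B  -> one 660 B byte buffer holds both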
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input,
bool rowwise) {
// Check input tensor
const NVTEScalingMode scaling_mode = input.scaling_mode();
NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
"Input tensor must be a block scaling tensor");
// Get tensor data
NVTEBasicTensor data;
size_t data_flat_first_dim = 1;
size_t data_flat_last_dim = 1;
if (rowwise) {
data = input.get_rowwise_data();
for (size_t i = 0; i < data.shape.ndim - 1; ++i) {
data_flat_first_dim *= data.shape.data[i];
}
data_flat_last_dim = data.shape.data[data.shape.ndim - 1];
} else {
data = input.get_columnwise_data();
data_flat_first_dim = data.shape.data[0];
for (size_t i = 1; i < data.shape.ndim; ++i) {
data_flat_last_dim *= data.shape.data[i];
}
}
NVTEShape data_shape{};
data_shape.data[0] = data_flat_first_dim;
data_shape.data[1] = data_flat_last_dim;
data_shape.ndim = 2;
// Recreate input tensor with rowwise usage
transformer_engine::TensorWrapper input_cu(scaling_mode);
input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
const NVTEBasicTensor scale_inv =
rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv();
input_cu.set_rowwise_scale_inv(
scale_inv.data_ptr, static_cast<transformer_engine::DType>(scale_inv.dtype), scale_inv.shape);
// Create output tensor
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
// Output swizzled mxfp8 scaling factor dimensions
const size_t swizzled_scale_inv_first_dim = ceildiv(data_flat_first_dim, 128) * 128;
const size_t swizzled_scale_inv_last_dim = ceildiv(data_flat_last_dim, 128) * 4;
// Allocate memory for swizzled mxfp8 scaling factors
at::Tensor swizzled_scale_inv =
allocateSpace(std::vector<size_t>{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim},
transformer_engine::DType::kByte, false);
// Set rowwise scaling factors on output
void *const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
NVTEShape swizzled_scale_inv_shape{};
swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim;
swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim;
swizzled_scale_inv_shape.ndim = 2;
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
swizzled_scale_inv_shape);
output_cu.set_with_gemm_swizzled_scales(true);
// Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format
NVTE_SCOPED_GIL_RELEASE({
nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
});
// Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling factor
// for it to be kept alive during the GEMM
input = std::move(output_cu);
return swizzled_scale_inv;
}
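
// Worked example of the swizzled scale dimensions above (illustrative
// numbers): for flattened data of shape {300, 1024},
//   first dim = ceildiv(300, 128) * 128 = 384  (rows padded to 128)
//   last dim  = ceildiv(1024, 128) * 4  =  32  (4 B per 128-column block)
// i.e. a {384, 32} byte tensor of E8M0 scaling factors.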
void inplace_swizzle_scale_for_gemm(py::handle &tensor) {
// Convert Python tensor to C++ tensor
auto tensor_nvte = makeTransformerEngineTensor(tensor, py::none());
// Return early if scale swizzling is not required
const auto scaling_mode = tensor_nvte.scaling_mode();
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING:
// Tensor format requires scale swizzling
break;
case NVTE_INVALID_SCALING:
NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
default:
// Tensor format does not require scale swizzling for GEMM
return;
}
// Return early if scales are already swizzled
if (tensor_nvte.get_with_gemm_swizzled_scales()) {
return;
}
// Check what scaling factors the tensor contains
auto is_empty = [](const NVTEBasicTensor &t) -> bool {
return t.shape.ndim == 1 && t.shape.data[0] == 0;
};
const bool has_rowwise_scales = !is_empty(tensor_nvte.get_rowwise_scale_inv());
const bool has_columnwise_scales = !is_empty(tensor_nvte.get_columnwise_scale_inv());
// Swizzle scaling factors
auto [rowwise_scales, columnwise_scales] =
swizzle_scales_for_gemm(tensor_nvte, has_rowwise_scales, has_columnwise_scales);
// Update Python tensor with swizzled scales
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING:
// Both formats store their scales in the same attributes
if (has_rowwise_scales) {
tensor.attr("_rowwise_scale_inv") = rowwise_scales;
}
if (has_columnwise_scales) {
tensor.attr("_columnwise_scale_inv") = columnwise_scales;
}
tensor.attr("_with_gemm_swizzled_scales") = true;
break;
default:
NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
}
}
} // namespace pytorch
} // namespace transformer_engine
......@@ -52,10 +52,12 @@ Quantizer::Quantizer(const py::handle& quantizer) {
this->rowwise_usage = true;
this->columnwise_usage = true;
this->internal = false;
this->optimize_for_gemm = false;
} else {
this->rowwise_usage = quantizer.attr("rowwise_usage").cast<bool>();
this->columnwise_usage = quantizer.attr("columnwise_usage").cast<bool>();
this->internal = quantizer.attr("internal").cast<bool>();
this->optimize_for_gemm = quantizer.attr("optimize_for_gemm").cast<bool>();
this->quantizer = quantizer;
}
}
......@@ -555,7 +557,6 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
"Unsupported block scaling dim.");
this->all_gather_usage = quantizer.attr("all_gather_usage").cast<bool>();
}
void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {}
......@@ -575,10 +576,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
Float8BlockScaleTensorFormat data_format =
(all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
: Float8BlockScaleTensorFormat::GEMM_READY);
if (rowwise_usage) {
data_rowwise = at::empty(torch_shape, opts);
auto scale_shape = get_scale_shape(shape, false);
......@@ -597,21 +594,13 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ",
columnwise_shape, " torch shape: ", torch_columnwise_shape);
if (torch_shape.size() > 0) {
if (!all_gather_usage) {
torch_columnwise_shape.reserve(torch_shape.size());
columnwise_shape.reserve(shape.size());
torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
columnwise_shape.push_back(shape[shape.size() - 1]);
for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
torch_columnwise_shape.push_back(torch_shape[i]);
columnwise_shape.push_back(shape[i]);
}
} else {
// assert we are doing 1D scaling
NVTE_CHECK(block_scaling_dim == 1,
"Compact columnwise format is not supported for 128x128 2D block scaling.");
torch_columnwise_shape = torch_shape;
columnwise_shape = shape;
torch_columnwise_shape.reserve(torch_shape.size());
columnwise_shape.reserve(shape.size());
torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
columnwise_shape.push_back(shape[shape.size() - 1]);
for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
torch_columnwise_shape.push_back(torch_shape[i]);
columnwise_shape.push_back(shape[i]);
}
}
auto scale_shape = get_scale_shape(shape, true);
......@@ -635,7 +624,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
"rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
"rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
"is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format);
"is_2D_scaled"_a = (block_scaling_dim == 2));
} else {
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
......@@ -643,8 +632,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
"shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
"columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
"columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2),
"data_format"_a = data_format);
"quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2));
}
return {std::move(tensor), std::move(ret)};
......@@ -654,6 +642,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
py::object tensor) const {
const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();
const bool with_gemm_swizzled_scales = true;
// Extract buffers from Python tensor
auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
......@@ -675,13 +664,10 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector<size_t> {
auto get_columnwise_shape = [&columnwise_data]() -> std::vector<size_t> {
if (!columnwise_data) {
return std::vector<size_t>();
}
if (all_gather_usage) {
return getTensorShape(*columnwise_data);
}
std::vector<size_t> shape = getTensorShape(*columnwise_data);
std::vector<size_t> shape_transposed(shape.size());
for (size_t i = 0; i + 1 < shape.size(); ++i) {
......@@ -696,12 +682,12 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
if (rowwise_data) {
shape = getTensorShape(*rowwise_data);
if (columnwise_data) {
auto expected_shape = get_columnwise_shape(all_gather_usage);
auto expected_shape = get_columnwise_shape();
NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape,
") and column-wise data (shape=", expected_shape, ") do not match");
}
} else {
shape = get_columnwise_shape(all_gather_usage);
shape = get_columnwise_shape();
}
std::vector<int64_t> torch_shape;
for (auto s : shape) {
......@@ -738,21 +724,13 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
std::vector<size_t> columnwise_shape;
std::vector<int64_t> torch_columnwise_shape;
if (torch_shape.size() > 0) {
if (!all_gather_usage) {
torch_columnwise_shape.reserve(torch_shape.size());
columnwise_shape.reserve(shape.size());
torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
columnwise_shape.push_back(shape[shape.size() - 1]);
for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
torch_columnwise_shape.push_back(torch_shape[i]);
columnwise_shape.push_back(shape[i]);
}
} else {
// assert we are doing 1D scaling
NVTE_CHECK(block_scaling_dim == 1,
"Compact columnwise format is not supported for 128x128 2D block scaling.");
torch_columnwise_shape = torch_shape;
columnwise_shape = shape;
torch_columnwise_shape.reserve(torch_shape.size());
columnwise_shape.reserve(shape.size());
torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
columnwise_shape.push_back(shape[shape.size() - 1]);
for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
torch_columnwise_shape.push_back(torch_shape[i]);
columnwise_shape.push_back(shape[i]);
}
}
if (!columnwise_data) {
......@@ -798,6 +776,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape);
}
ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
set_quantization_params(&ret);
return {std::move(ret), std::move(tensor)};
}
......@@ -813,9 +792,6 @@ void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& o
}
quant_config.set_force_pow_2_scales(force_pow_2_scales);
quant_config.set_amax_epsilon(amax_epsilon);
if (all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
......@@ -832,10 +808,6 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128;
Float8BlockScaleTensorFormat data_format =
(all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
: Float8BlockScaleTensorFormat::GEMM_READY);
std::vector<size_t> scale_shape;
bool rowwise_usage = !columnwise;
......@@ -845,26 +817,17 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
sinv0 = ceildiv(m_dim, kBlockLen);
sinv1 = roundup(ceildiv(k_dim, kBlockLen), 4);
} else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
// default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
// if the rowwise format is compact, the scaling factor is not be transposed
if (rowwise_compact) {
std::swap(sinv0, sinv1);
}
sinv0 = ceildiv(k_dim, kBlockLen);
sinv1 = roundup(m_dim, 4);
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor rowwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor rowwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
}
scale_shape = {sinv0, sinv1};
} else {
......@@ -872,24 +835,16 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
sinv0 = ceildiv(k_dim, kBlockLen);
sinv1 = roundup(ceildiv(m_dim, kBlockLen), 4);
} else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
// GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS
// for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1]
// so no need to swap sinv0 and sinv1 here
sinv0 = ceildiv(m_dim, kBlockLen);
sinv1 = roundup(k_dim, 4);
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor columnwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor columnwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
}
scale_shape = {sinv0, sinv1};
}
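
With the COMPACT format removed, these scale shapes are always the GEMM-ready
ones. A worked example of the formulas above for a flattened (m, k) =
(300, 1024) tensor (illustrative numbers):

    constexpr size_t m = 300, k = 1024, kBlockLen = 128;
    // 1D row-wise: {ceildiv(k, kBlockLen), roundup(m, 4)} == {8, 300}
    // 2D row-wise: {ceildiv(m, kBlockLen), roundup(ceildiv(k, kBlockLen), 4)} == {3, 8}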
......@@ -906,6 +861,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
DType dtype) const {
using namespace pybind11::literals;
// Scaling factor format
const bool with_gemm_swizzled_scales = this->optimize_for_gemm;
// Tensor dimensions
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
size_t flat_first_dim = 1;
......@@ -951,19 +909,17 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
py::object out_py;
if (internal) {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass));
out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py,
"columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
out_py = MXFP8TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py,
columnwise_scale_inv_py, this->dtype, this->quantizer,
with_gemm_swizzled_scales);
} else {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass));
out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
"rowwise_data"_a = rowwise_data_py,
"columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
out_py = MXFP8TensorClass(
"shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
"rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py, "fp8_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer, "with_gemm_swizzled_scales"_a = with_gemm_swizzled_scales);
}
// Construct C++ MXFP8 tensor
......@@ -978,6 +934,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0,
columnwise_scale_inv_shape);
}
out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(out_py)};
......@@ -987,6 +944,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
py::object tensor) const {
NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor.");
// Scaling factor format
const bool with_gemm_swizzled_scales = this->optimize_for_gemm;
// Extract buffers from Python tensor
auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
auto attr_py = tensor.attr(name);
......@@ -1070,6 +1030,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
// Coerce other attrs
tensor.attr("_fp8_dtype") = dtype;
tensor.attr("_with_gemm_swizzled_scales") = with_gemm_swizzled_scales;
// Construct C++ MXFP8 tensor
TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING);
......@@ -1083,6 +1044,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0,
getTensorShape(*columnwise_scale_inv));
}
out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)};
......@@ -1173,6 +1135,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
DType dtype) const {
using namespace pybind11::literals;
// Scaling factor format
const bool with_gemm_swizzled_scales = false;  // TODO (tmoon) Enable with this->optimize_for_gemm
// Tensor dimensions
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
size_t flat_first_dim = 1;
......@@ -1235,12 +1200,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
py::object out_py;
if (internal) {
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass));
out_py = NVFP4TensorClass(
"rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
"amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer);
out_py = NVFP4TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py,
columnwise_scale_inv_py, amax_rowwise_py, amax_columnwise_py,
this->dtype, this->quantizer, with_gemm_swizzled_scales);
} else {
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass));
out_py = NVFP4TensorClass(
......@@ -1249,7 +1211,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
"amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer);
"quantizer"_a = this->quantizer, "with_gemm_swizzled_scales"_a = with_gemm_swizzled_scales);
}
// Construct C++ tensor
......@@ -1272,6 +1234,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(out_py)};
......@@ -1301,6 +1264,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
py::object tensor) const {
NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to NVFP4Tensor.");
// Scaling factor format
const bool with_gemm_swizzled_scales = false; // TODO (tmoon) Enable with optimize_for_gemm
// Extract buffers from Python tensor
auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
auto attr_py = tensor.attr(name);
......@@ -1438,6 +1404,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)};
......
......@@ -55,8 +55,9 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer
TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) {
auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING);
bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast<bool>();
NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for MXFP8 Tensor.");
......@@ -78,6 +79,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer)
getTensorShape(scale_inv));
}
// Scale layout
ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
// Quantizer state
quantizer->set_quantization_params(&ret);
......@@ -93,6 +97,7 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
// Row-wise data
if (rowwise_usage) {
const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
......@@ -102,6 +107,8 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise);
ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape);
}
// Column-wise data
if (columnwise_usage) {
const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
......@@ -112,7 +119,10 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape);
}
// Quantizer state
quantizer->set_quantization_params(&ret);
return ret;
}
......@@ -121,8 +131,9 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
auto ret = TensorWrapper(NVTE_NVFP4_1D_SCALING);
bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast<bool>();
NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor.");
......@@ -150,6 +161,9 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
getTensorShape(amax_columnwise));
}
// Scale layout
ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
// Quantizer state
quantizer->set_quantization_params(&ret);
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "util.h"
#include "common.h"
#include "common/common.h"
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
bool rowwise) {
using namespace transformer_engine::pytorch;
if (input.scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING &&
input.scaling_mode() != NVTE_NVFP4_1D_SCALING) {
return std::nullopt;
}
NVTE_CHECK(input.element_size_bits() == 4 || input.element_size_bits() == 8,
"4-bit or 8-bit input required for swizzling scaling factors.");
const auto nvfp4 = input.scaling_mode() == NVTE_NVFP4_1D_SCALING;
NVTEBasicTensor scale_inv;
NVTEShape nvte_input_shape;
if (rowwise) {
nvte_input_shape = input.shape();
scale_inv = input.get_rowwise_scale_inv();
} else {
nvte_input_shape = input.get_columnwise_data().shape;
scale_inv = input.get_columnwise_scale_inv();
}
auto input_shape = nvte_shape_to_vector(nvte_input_shape);
auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape.");
// Allocate memory for swizzled output.
auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
std::vector<int64_t> scale_inv_shape_int;
for (size_t i = 0; i < scale_inv_shape.size(); ++i) {
scale_inv_shape_int.push_back(static_cast<int64_t>(scale_inv_shape[i]));
}
auto swizzled_scale_inv = at::empty(scale_inv_shape_int, options);
void* scale_inv_dptr = scale_inv.data_ptr;
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
transformer_engine::TensorWrapper input_cu(input.scaling_mode());
transformer_engine::TensorWrapper output_cu(input.scaling_mode());
const auto input_dtype =
(nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3;
const auto scale_inv_dtype =
(nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0;
if (rowwise) {
input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
} else {
input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
}
// Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
} else {
input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
}
return swizzled_scale_inv;
}
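
// For reference (as selected above): MXFP8 uses E8M0 scales and NVFP4 uses
// FP8 E4M3 scales, both one byte per element, so the swizzled buffer can be
// allocated as kByte with the same shape in either case.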
std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
std::vector<transformer_engine::TensorWrapper>& tensors, bool rowwise) {
using namespace transformer_engine::pytorch;
if (tensors.empty()) {
return std::nullopt;
}
bool all_same_scaling_mode = std::all_of(
tensors.cbegin(), tensors.cend(), [&tensors](const transformer_engine::TensorWrapper& val) {
return val.scaling_mode() == tensors.front().scaling_mode();
});
NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same.");
if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING &&
tensors.front().scaling_mode() != NVTE_NVFP4_1D_SCALING) {
return std::nullopt;
}
const auto scaling_mode = tensors.front().scaling_mode();
const auto nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;
std::vector<transformer_engine::TensorWrapper> wrappers;
std::vector<NVTETensor> input_tensors, output_tensors;
// Collect scale_inv shapes and calculate buffer size and offsets for scale_invs
std::vector<std::vector<size_t>> scale_inv_shapes;
std::vector<void*> scale_inv_dptrs;
size_t buffer_size = 0;
std::vector<size_t> scale_inv_offsets;
constexpr size_t scale_elem_size = 1;
for (auto& tensor : tensors) {
NVTEBasicTensor scale_inv;
if (rowwise) {
scale_inv = tensor.get_rowwise_scale_inv();
} else {
scale_inv = tensor.get_columnwise_scale_inv();
}
auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
buffer_size = roundup(buffer_size, 16); // align to 16B
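// (16B-aligned offsets presumably allow vectorized accesses in the swizzle kernel.)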
scale_inv_offsets.push_back(buffer_size);
buffer_size += product(scale_inv_shape) * scale_elem_size;
scale_inv_shapes.emplace_back(scale_inv_shape);
scale_inv_dptrs.push_back(scale_inv.data_ptr);
}
// Allocate full buffer
auto buffer = at::empty({static_cast<int64_t>(buffer_size)}, at::device(at::kCUDA).dtype(torch::kUInt8));
const auto input_dtype =
(nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3;
const auto scale_inv_dtype =
(nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0;
for (size_t i = 0; i < tensors.size(); ++i) {
auto& tensor = tensors[i];
void* scale_inv_dptr = scale_inv_dptrs[i];
void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);
// Empty tensors don't require scale swizzling
if (tensor.numel() == 0) {
continue;
}
// Tensor shape
NVTEShape nvte_input_shape;
if (rowwise) {
nvte_input_shape = tensor.shape();
} else {
nvte_input_shape = tensor.get_columnwise_data().shape;
}
auto input_shape = nvte_shape_to_vector(nvte_input_shape);
// Reconstruct the input so that only the needed direction is swizzled.
transformer_engine::TensorWrapper input_cu(scaling_mode);
transformer_engine::TensorWrapper output_cu(scaling_mode);
if (rowwise) {
input_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
output_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
scale_inv_shapes[i]);
// Attach the swizzled scaling factors to the original tensor.
tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
} else {
input_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
output_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
scale_inv_shapes[i]);
// Attach the swizzled scaling factors to the original tensor.
tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
scale_inv_shapes[i]);
}
input_tensors.emplace_back(input_cu.data());
output_tensors.emplace_back(output_cu.data());
wrappers.emplace_back(std::move(input_cu));
wrappers.emplace_back(std::move(output_cu));
}
// Launch kernel
nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(),
input_tensors.size(), at::cuda::getCurrentCUDAStream());
return buffer;
}
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input,
bool rowwise) {
using namespace transformer_engine::pytorch;
using transformer_engine::DIVUP;
// Check input tensor
const NVTEScalingMode scaling_mode = input.scaling_mode();
NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
"Input tensor must be a block scaling tensor");
// Get tensor data
NVTEBasicTensor data;
size_t data_flat_first_dim = 1;
size_t data_flat_last_dim = 1;
if (rowwise) {
data = input.get_rowwise_data();
for (size_t i = 0; i < data.shape.ndim - 1; ++i) {
data_flat_first_dim *= data.shape.data[i];
}
data_flat_last_dim = data.shape.data[data.shape.ndim - 1];
} else {
data = input.get_columnwise_data();
data_flat_first_dim = data.shape.data[0];
for (size_t i = 1; i < data.shape.ndim; ++i) {
data_flat_last_dim *= data.shape.data[i];
}
}
NVTEShape data_shape{};
data_shape.data[0] = data_flat_first_dim;
data_shape.data[1] = data_flat_last_dim;
data_shape.ndim = 2;
// Recreate input tensor with rowwise usage
transformer_engine::TensorWrapper input_cu(scaling_mode);
input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
const NVTEBasicTensor scale_inv =
rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv();
input_cu.set_rowwise_scale_inv(
scale_inv.data_ptr, static_cast<transformer_engine::DType>(scale_inv.dtype), scale_inv.shape);
// Create output tensor
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
// Output swizzled mxfp8 scaling factor dimensions
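// cuBLAS's swizzled layout tiles the scales in 128x4 blocks: rows are padded
// to a multiple of 128, and each 128 data columns contribute 4 E8M0 scales
// (one per 32-element MXFP8 block).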
const size_t swizzled_scale_inv_first_dim = DIVUP<size_t>(data_flat_first_dim, 128) * 128;
const size_t swizzled_scale_inv_last_dim = DIVUP<size_t>(data_flat_last_dim, 128) * 4;
// Allocate memory for swizzled mxfp8 scaling factors
const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
at::Tensor swizzled_scale_inv =
at::empty({static_cast<int64_t>(swizzled_scale_inv_first_dim),
static_cast<int64_t>(swizzled_scale_inv_last_dim)},
options);
// Set rowwise scaling factors on output
void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
NVTEShape swizzled_scale_inv_shape{};
swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim;
swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim;
swizzled_scale_inv_shape.ndim = 2;
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
swizzled_scale_inv_shape);
// Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format
nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
// Replace the input tensor with the converted mxfp8 tensor and return the
// swizzled scaling factor so the caller can keep it alive during the GEMM
input = std::move(output_cu);
return swizzled_scale_inv;
}
@@ -10,33 +10,44 @@
#include <torch/extension.h>
#include <optional>
#include <tuple>
#include <vector>
#include "transformer_engine/transformer_engine.h"
/*! \brief Swizzle the scaling factor of the input tensor.
namespace transformer_engine {
namespace pytorch {
/*! \brief Convert tensor block scales into GEMM swizzled format.
*
* The returned swizzled scaling factor tensor should be kept alive during the GEMM.
* The returned swizzled scales should be kept alive during the GEMM.
*/
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
bool rowwise);
std::tuple<std::optional<at::Tensor>, std::optional<at::Tensor>> swizzle_scales_for_gemm(
TensorWrapper& tensor, bool rowwise_usage, bool columnwise_usage);
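// Hedged usage sketch (GEMM call elided): the returned buffers own the
// swizzled scales and must outlive the GEMM.
//
//   auto [row_scales, col_scales] =
//       swizzle_scales_for_gemm(tensor, /*rowwise_usage=*/true,
//                               /*columnwise_usage=*/false);
//   // ... launch the GEMM reading `tensor` ...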
/*! \brief Swizzle the scaling factor of the input tensors.
/*! \brief Convert multiple tensor block scales into GEMM swizzled format.
*
* The returned swizzled scaling factor tensors should be kept alive during the GEMMs.
* The returned swizzled scales should be kept alive during the GEMMs.
*/
std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
std::vector<transformer_engine::TensorWrapper> &inputs, bool rowwise);
std::optional<at::Tensor> multi_tensor_swizzle_scales_for_gemm(std::vector<TensorWrapper>& tensors,
bool rowwise_usage,
bool columnwise_usage);
/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place.
*
* If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid
* transposing it in memory. Due to differences in how block scaling and mxfp8 store data,
* this requires the calling code to treat the output tensor as having been tranposed in this case.
* If rowwise==false, the columnwise data will be reinterpreted as
* rowwise data to avoid transposing it in memory. Due to differences
* in how block scaling and mxfp8 store data, this requires the
* calling code to treat the output tensor as having been transposed
* in this case.
*
* Returns the swizzled scaling factor of the converted mxfp8 tensor.
* The returned swizzled scaling factor tensor should be kept alive during the GEMM.
* Returns the swizzled scaling factor of the converted mxfp8 tensor.
* The returned swizzled scaling factor tensor should be kept alive
* during the GEMM.
*/
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input,
bool rowwise);
at::Tensor convert_block_scaling_to_mxfp8_tensor(TensorWrapper& input, bool rowwise);
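// Hedged usage sketch (caller details assumed): with rowwise==false, the
// converted tensor is logically transposed, so the GEMM must flip its
// transpose flag for this operand.
//
//   at::Tensor scales = convert_block_scaling_to_mxfp8_tensor(tensor, /*rowwise=*/false);
//   // ... run the GEMM, treating `tensor` as transposed; `scales` must stay
//   // alive until it completes ...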
} // namespace pytorch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
@@ -48,7 +48,7 @@ from .tensor.storage.float8_tensor_storage import Float8TensorStorage
from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
__all__ = ["checkpoint", "CudaRNGStatesTracker"]
@@ -930,6 +930,34 @@ def reduce_scatter_along_first_dim(
return output, handle
@dataclass
class _AsyncHandle:
"""Handle for asynchronous collectives."""
async_handle: torch.distributed.Work
post_process_function: Optional[Callable] = None
post_process_function_args: Optional[Tuple[Any, ...]] = None
post_process_function_kwargs: Optional[Dict[str, Any]] = None
_synchronized: bool = False
def wait(self) -> None:
"""Synchronize the asynchronous communicaton.
Perform post-processing if needed.
"""
if self._synchronized:
return
self.async_handle.wait()
if self.post_process_function is not None:
args = self.post_process_function_args
args = () if args is None else args
kwargs = self.post_process_function_kwargs
kwargs = {} if kwargs is None else kwargs
self.post_process_function(*args, **kwargs)
self._synchronized = True
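# Hypothetical helper (illustrative only, not part of this module): wrap an
# async all-gather in an _AsyncHandle so post-processing runs exactly once
# when the caller synchronizes.
def _example_async_all_gather(inp, out, process_group, world_size):
    """Sketch of _AsyncHandle usage; the argument plumbing is assumed."""
    work = torch.distributed.all_gather_into_tensor(
        out, inp, group=process_group, async_op=True
    )
    return _AsyncHandle(
        work,
        post_process_function=_finish_all_gather_fp8_blockwise,
        post_process_function_args=(out, world_size, None, None),
    )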
def _all_gather_fp8(
inp: torch.Tensor,
process_group: dist_group_type,
@@ -1020,73 +1048,7 @@ def _all_gather_fp8(
return out, handle
def _get_quantizer_format(quantizer: Quantizer) -> Optional[bool]:
"""Get quantizer format."""
if isinstance(quantizer, DebugQuantizer):
quantizer = quantizer.parent_quantizer
if isinstance(quantizer, Float8BlockQuantizer):
return quantizer.all_gather_usage
return None
def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
"""Make quantizer compact"""
_quantizer = quantizer
if isinstance(quantizer, DebugQuantizer):
_quantizer = quantizer.parent_quantizer
if isinstance(_quantizer, Float8BlockQuantizer):
_quantizer.all_gather_usage = compact
def _post_process_fp8_blockwise_gather(
out: Float8BlockwiseQTensorStorage,
quantizer: Float8BlockQuantizer,
handle: Optional[torch.distributed.Work] = None,
) -> Float8BlockwiseQTensorStorage:
"""Post-process FP8 blockwise gather."""
if handle is not None:
handle.wait()
handle = None
if out._is_gemm_ready_format():
return out
needs_columnwise_data_transpose = quantizer is not None and quantizer.columnwise_usage
need_rowwise_scale_transpose = quantizer is not None and quantizer.rowwise_usage
# cuBLAS requires a transposed scale-inv tensor. Suppose the original input is
# 256x1024: columnwise compact format applies 128x1 quantization, so the
# quantized tensor is 256x1024 and the scale inv is 2x1024. The GEMM_READY
# format is equivalent to 1x128 quantization of the transposed 1024x256
# tensor, giving a 1024x2 scale inv, of which cuBLAS requires the 2x1024
# transpose. Therefore the scale inv needs no transpose here; only the
# columnwise data does.
if needs_columnwise_data_transpose:
out._transpose_columnwise_data()
if need_rowwise_scale_transpose:
out._rowwise_scale_inv = out._rowwise_scale_inv.transpose(-2, -1).contiguous()
out._data_format = tex.Float8BlockScaleTensorFormat.GEMM_READY
return out
@dataclass
class _FP8BlockwiseAllGatherAsyncHandle:
"""Handle for asynchronous FP8 blockwise all-gather."""
tensor: Float8BlockwiseQTensorStorage
quantizer: Float8BlockQuantizer
async_handle: torch.distributed.Work
_synchronized: bool = False
def wait(self) -> None:
"""Wait for the async operation to complete and post-process the tensor."""
if self._synchronized:
return
self.async_handle.wait()
_post_process_fp8_blockwise_gather(self.tensor, self.quantizer)
self._synchronized = True
def _all_gather_fp8_blockwise(
def _start_all_gather_fp8_blockwise(
inp: torch.Tensor,
process_group: dist_group_type,
*,
@@ -1125,44 +1087,25 @@ def _all_gather_fp8_blockwise(
)
world_size = get_distributed_world_size(process_group)
# Check that quantizer is valid
if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer):
raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128):
raise NotImplementedError("Only 1D blockwise quantization is supported for allgather")
# Output tensor dims
if out_shape is None:
out_shape = list(inp.size())
out_shape[0] *= world_size
# Doing BF16 gather for now as baseline because it's simpler
if (
not isinstance(inp, Float8BlockwiseQTensorStorage)
and quantizer is not None
and not quantizer.is_quantizable(inp)
):
out = torch.empty(
out_shape,
dtype=dtype,
device=device,
memory_format=torch.contiguous_format,
)
# Check that quantizer is valid
if quantizer is None:
raise ValueError("Quantizer is missing")
if not isinstance(quantizer, Float8BlockQuantizer):
raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
# Fall back to high-precision all-gather if FP8 is not supported
if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
out = torch.empty(out_shape, dtype=dtype, device=device)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
orig_all_gather_usage = quantizer.all_gather_usage
quantizer.all_gather_usage = False
out = quantizer(out)
quantizer.all_gather_usage = orig_all_gather_usage
return out, None
# Implementation of FP8 gather needs to account for:
# * Getting columnwise data as a transpose of how it is stored for GEMMs.
# * Gathering non-GEMM-swizzled scales.
# Cast input tensor to Float8BlockwiseQTensor with required data
# Set to compact usage in case the quantizer is not correctly configured
orig_all_gather_usage = quantizer.all_gather_usage
quantizer.all_gather_usage = True
# Quantize input tensor if needed
if not isinstance(inp, Float8BlockwiseQTensorStorage):
inp = quantizer(inp)
elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
@@ -1177,14 +1120,9 @@ def _all_gather_fp8_blockwise(
# Construct Float8BlockwiseQTensor output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
quantizer.all_gather_usage = orig_all_gather_usage
# Begin to do network communication, need to make sure compact format
if inp._data_format != tex.Float8BlockScaleTensorFormat.COMPACT:
raise RuntimeError(
"All-gather with FP8 block-wise quantized tensor requires compact data format, "
f"but found data_format={inp._data_format}"
)
# Temporary buffers for all-gathering transposed buffers
interleaved_rowwise_scale_inv = None
interleaved_columnwise_data = None
# Coalesce NCCL collectives
with torch.distributed._coalescing_manager(
@@ -1193,11 +1131,17 @@
async_ops=async_op,
) as coalescing_manager:
# Gather Float8BlockwiseQTensor data for row-wise usage
# Gather row-wise data
if quantizer.rowwise_usage:
# Launch all-gathers
scale_inv_shape = list(inp._rowwise_scale_inv.size())
scale_inv_shape[0] *= world_size
interleaved_rowwise_scale_inv = torch.empty(
scale_inv_shape,
dtype=inp._rowwise_scale_inv.dtype,
device=device,
)
torch.distributed.all_gather_into_tensor(
out._rowwise_scale_inv,
interleaved_rowwise_scale_inv,
inp._rowwise_scale_inv,
group=process_group,
)
@@ -1207,36 +1151,73 @@
group=process_group,
)
# Gather Float8BlockwiseQTensor data for column-wise usage
# Column-wise data
if quantizer.columnwise_usage:
# Launch all-gathers
data_shape = list(inp._columnwise_data.size())
data_shape[0] *= world_size
interleaved_columnwise_data = torch.empty(
data_shape,
dtype=inp._columnwise_data.dtype,
device=device,
)
torch.distributed.all_gather_into_tensor(
out._columnwise_scale_inv,
inp._columnwise_scale_inv,
group=process_group,
)
torch.distributed.all_gather_into_tensor(
out._columnwise_data,
interleaved_columnwise_data,
inp._columnwise_data,
group=process_group,
)
handle = coalescing_manager if async_op else None
# Unlike MXFP8, this fp8 blockwise tensor primarily works with Hopper
# This means that we need to transpose the gathered columnwise data
# Example usage is grad_output tensor, ie. dY in linear backward
# We want to gather two FP8 tensors (rowwise and columnwise) along dim0
# and then transpose the columnwise data to match the rowwise data
# Make sure FP8 transpose is populated if needed
# Finalize communication if needed
async_handle = None
if async_op:
handle = _FP8BlockwiseAllGatherAsyncHandle(out, quantizer, handle)
async_handle = _AsyncHandle(
coalescing_manager,
post_process_function=_finish_all_gather_fp8_blockwise,
post_process_function_args=(
out,
world_size,
interleaved_rowwise_scale_inv,
interleaved_columnwise_data,
),
)
else:
# For a sync op, run the post-processing (layout fixes) immediately
_post_process_fp8_blockwise_gather(out, quantizer, handle)
_finish_all_gather_fp8_blockwise(
out,
world_size,
interleaved_rowwise_scale_inv,
interleaved_columnwise_data,
)
return out, handle
return out, async_handle
def _finish_all_gather_fp8_blockwise(
out: Float8BlockwiseQTensorStorage,
world_size: int,
interleaved_rowwise_scale_inv: Optional[torch.Tensor],
interleaved_columnwise_data: Optional[torch.Tensor],
) -> Float8BlockwiseQTensorStorage:
"""Post-process FP8 blockwise gather."""
# Fix interleaving in row-wise scales
if interleaved_rowwise_scale_inv is not None:
dim0 = out._rowwise_scale_inv.size(0)
view_in = interleaved_rowwise_scale_inv.view(world_size, dim0, -1)
view_out = out._rowwise_scale_inv.view(dim0, world_size, -1)
tex.swap_first_dims(view_in, out=view_out)
# Fix interleaving in column-wise data
if interleaved_columnwise_data is not None:
dim0 = out._columnwise_data.size(0)
view_in = interleaved_columnwise_data.view(world_size, dim0, -1)
view_out = out._columnwise_data.view(dim0, world_size, -1)
tex.swap_first_dims(view_in, out=view_out)
return out
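# Hedged reference for tex.swap_first_dims (assumed semantics): the gathered
# buffer is rank-major, viewed as [world_size, dim0, ...], while the target
# layout interleaves ranks as [dim0, world_size, ...]. A pure-PyTorch
# equivalent would be a transpose-and-copy:
def _swap_first_dims_reference(view_in: torch.Tensor, view_out: torch.Tensor) -> None:
    """Illustrative stand-in for tex.swap_first_dims, not part of this module."""
    view_out.copy_(view_in.transpose(0, 1))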
def _swap_first_dims(tensor: torch.Tensor, world_size: int):
@@ -1250,7 +1231,7 @@ def _swap_first_dims(tensor: torch.Tensor, world_size: int):
"""
shape = tensor.shape
assert tensor.ndim >= 2, "Wrong number of dimensions for fixing interleave."
assert len(shape) >= 2, "Wrong number of dimensions for fixing interleave."
first_dim = shape[0]
flattened_trailing = math.prod(shape[1:])
assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave."
@@ -1681,7 +1662,7 @@ def gather_along_first_dim(
if isinstance(inp, Float8BlockwiseQTensorStorage) or isinstance(
quantizer, Float8BlockQuantizer
):
return _all_gather_fp8_blockwise(
return _start_all_gather_fp8_blockwise(
inp,
process_group,
async_op=async_op,
@@ -1719,10 +1700,6 @@
)
if isinstance(inp, QuantizedTensorStorage):
inp = inp.dequantize()
# Falling back to high-precision all-gather for Float8BlockQuantizer
# means that it should directly output GEMM_READY format
compact = _get_quantizer_format(quantizer)
_set_quantizer_format(quantizer, compact=False)
out = torch.empty(
out_shape,
dtype=inp.dtype,
@@ -1731,7 +1708,6 @@
)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
out = quantizer(out)
_set_quantizer_format(quantizer, compact=compact)
return out, None
# Dequantize quantized tensor if not supported
@@ -560,6 +560,8 @@ def fill_userbuffers_buffer_for_all_gather(
"Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
f"but got MXFP8 tensor with shape={tuple(local_shape)}"
)
if local_tensor._with_gemm_swizzled_scales:
raise ValueError("Userbuffers assumes MXFP8 tensors have unswizzled scales")
local_scale_inv = (
local_tensor._rowwise_scale_inv
if with_rowwise_data
@@ -592,6 +594,7 @@
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=local_tensor._fp8_dtype,
quantizer=quantizer,
with_gemm_swizzled_scales=False,
)
return global_tensor, local_tensor
@@ -720,13 +720,9 @@ class GroupedLinear(TransformerEngineBaseModule):
"""Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe)
# customize quantizers based on each recipe & layer configs
# Recipe-specific quantizer configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
assert not self.tp_size > 1, (
"GroupedLinear doesn't support TP > 1 with Float8 current scaling. "
"Because the TP communication is handled outside of this module."
)
self._customize_quantizers_float8_current_scaling(fwd, recipe)
def reset_parameters(self, defer_init=False):
@@ -879,9 +875,12 @@
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
assert (
recipe.float8_current_scaling()
), "Expected Float8 current scaling recipe"
assert not self.tp_size > 1, (
"GroupedLinear doesn't support TP > 1 with Float8 current scaling "
"because the TP communication is handled outside of this module."
)
if fwd:
for i in range(self.num_gemms):
# set configs about amax epsilon and power_2_scale
@@ -954,9 +953,9 @@
]
for i in range(self.num_gemms)
]
# TODO: use internal after #1638 is merged. # pylint: disable=fixme
for i in range(self.num_gemms):
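# optimize_for_gemm presumably makes the quantizer emit scales directly in
# the GEMM-swizzled layout, skipping the separate swizzle pass before cuBLAS.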
input_quantizers[i].internal = False
input_quantizers[i].internal = True
input_quantizers[i].optimize_for_gemm = True
if torch.is_grad_enabled():
grad_output_quantizers = [
self.quantizers["scaling_bwd"][
@@ -966,6 +965,7 @@
]
for i in range(self.num_gemms):
grad_output_quantizers[i].internal = True
grad_output_quantizers[i].optimize_for_gemm = True
return (
input_quantizers,
weight_quantizers,
@@ -64,7 +64,6 @@ from ..quantized_tensor import (
restore_from_saved,
)
from ...debug.pytorch.debug_state import TEDebugState
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..cpu_offload import (
is_cpu_offload_enabled,
@@ -253,8 +252,6 @@ class _LayerNormLinear(torch.autograd.Function):
if fp8 or debug:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(input_quantizer, Float8BlockQuantizer):
input_quantizer.all_gather_usage = False
ln_out_total = input_quantizer(ln_out_total)
else:
quantizer = None
@@ -1409,15 +1406,12 @@
"""Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe)
# customize quantizers based on each recipe & layer configs
# Recipe-specific quantizer configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)
# elif other recipes (mxfp8, etc)
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
@@ -1619,12 +1613,16 @@
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
input_quantizer.internal = True
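# Scales in GEMM-swizzled layout presumably cannot be all-gathered directly,
# so keep the unswizzled layout when sequence parallelism gathers this tensor.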
if not (self.parallel_mode == "column" and self.sequence_parallel):
input_quantizer.optimize_for_gemm = True
(weight_quantizer,) = self._get_weight_quantizers()
if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if is_grad_enabled:
grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
grad_output_quantizer.internal = True
if not (self.parallel_mode == "row" and self.sequence_parallel):
grad_output_quantizer.optimize_for_gemm = True
if fp8_grad:
grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
@@ -1808,14 +1806,3 @@
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
return [weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_linear."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
@@ -431,8 +431,6 @@ class _LayerNormMLP(torch.autograd.Function):
if fp8 or debug:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
fc1_input_quantizer.all_gather_usage = False
ln_out_total = fc1_input_quantizer(ln_out_total)
else:
quantizer = None
@@ -1964,15 +1962,12 @@
"""Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe)
# customize quantizers based on each recipe & layer configs
# Recipe-specific quantizer configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
@@ -2193,6 +2188,8 @@
if self.fp8 or self.fp8_calibration:
fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
fc1_input_quantizer.internal = True
if not self.sequence_parallel:
fc1_input_quantizer.optimize_for_gemm = True
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
fc2_input_quantizer.set_usage(
rowwise=True,
@@ -2201,7 +2198,8 @@
(MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer),
),
)
fc1_input_quantizer.internal = True
fc2_input_quantizer.internal = True
fc2_input_quantizer.optimize_for_gemm = True
if fp8_output:
fc2_output_quantizer = self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM2_OUTPUT
@@ -2211,10 +2209,13 @@
tex.FP8BwdTensors.GRAD_OUTPUT2
]
fc2_grad_output_quantizer.internal = True
if not self.sequence_parallel:
fc2_grad_output_quantizer.optimize_for_gemm = True
fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
]
fc1_grad_output_quantizer.internal = True
fc1_grad_output_quantizer.optimize_for_gemm = True
return (
fc1_input_quantizer,
@@ -2467,22 +2468,6 @@
fc2_weight_quantizer.internal = True
return [fc1_weight_quantizer, fc2_weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_mlp."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
else:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
].all_gather_usage = True
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
@@ -1313,15 +1313,12 @@ class Linear(TransformerEngineBaseModule):
"""Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe)
# customize quantizers based on each recipe & layer configs
# Recipe-specific quantizer configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)
@@ -1489,12 +1486,16 @@
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
input_quantizer.internal = True
if not (self.parallel_mode == "column" and self.sequence_parallel):
input_quantizer.optimize_for_gemm = True
(weight_quantizer,) = self._get_weight_quantizers()
if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if is_grad_enabled:
grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
grad_output_quantizer.internal = True
if not (self.parallel_mode == "row" and self.sequence_parallel):
grad_output_quantizer.optimize_for_gemm = True
if fp8_grad:
grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
return (
@@ -1669,22 +1670,3 @@
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
return [weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + linear."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# set compact for inp tensor X
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
else:
if self.sequence_parallel and self.parallel_mode == "row":
# set compact for grad_output tensor dY
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].all_gather_usage = True
@@ -342,15 +342,21 @@ class BasicLinear(BasicOperation):
def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
super().reset_recipe_state(recipe=recipe)
# Input/grad output quantizers use internal tensors
# Configure input/grad output quantizers
# Note: These tensors are only used internally. If there is no
# tensor-parallel communication, they are only used for GEMM.
input_quantizer = self.get_quantizer("forward", 0)
grad_output_quantizer = self.get_quantizer("backward", 0)
if input_quantizer is not None:
input_quantizer.internal = True
if not (self.tensor_parallel_mode == "column" and self.sequence_parallel):
input_quantizer.optimize_for_gemm = True
if grad_output_quantizer is not None:
grad_output_quantizer.internal = True
if not (self.tensor_parallel_mode == "row" and self.sequence_parallel):
grad_output_quantizer.optimize_for_gemm = True
# Handle weight quantizer
# Configure weight quantizer
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
weight_quantizer = self.get_quantizer("forward", 1)
@@ -292,6 +292,7 @@ class UserbuffersBackwardLinear(FusedOperation):
rowwise=True,
columnwise=with_columnwise,
)
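# Userbuffers assumes unswizzled scales, so disable GEMM-swizzled output here.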
grad_output_quantizer.optimize_for_gemm = False
dy_local = grad_output_quantizer(dy_local)
else:
dy_local = maybe_dequantize(dy_local, dtype)
@@ -294,6 +294,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
columnwise_scale_inv=None,
quantizer=None,
requires_grad=output.requires_grad,
with_gemm_swizzled_scales=False,
)
ctx.save_for_backward(row_id_map, pad_offsets)
@@ -504,6 +505,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
columnwise_scale_inv=None,
quantizer=None,
requires_grad=act_grad.requires_grad,
with_gemm_swizzled_scales=False,
)
if not ctx.needs_input_grad[2]: