Unverified Commit 99df8810 authored by Tim Moon, committed by GitHub

Add logic for block-scaled tensors with GEMM swizzled scales (#2486)

* Add general C API for setting tensor params
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Implement general accessors for NVTETensor
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor tex swizzling to skip if scales are already swizzled
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add checks for non-swizzled scales in MXFP8 and NVFP4 kernels
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support pre-swizzled scales in MXFP8Tensor
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tex function to swizzle MXFP8 scales
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in inplace swizzle function
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak comments to use "compact/swizzled format"
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci

* MXFP8 quantize kernel with pre-swizzled scales
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expose pre-swizzled scales in modules
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in multi-swizzle
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support MXFP8 gated activations with swizzled scales
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci

* Add PyTorch infrastructure for pre-swizzled NVFP4 tensors
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci

* Deprecate DSv3-specific quantization logic in C API
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci

* Remove support for DSv3 compact data from quantizer
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove DSv3 compact data format from core lib
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in FP8 all-gather
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update JAX to use new swizzled scale API
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci

* Review suggestion from @greptile-apps
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestions from @greptile-apps
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci

* Update C++ swizzle test with swizzled scales API
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Return default tensor params when querying params for an invalid NVTETensor
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug DSv3 FP8 test failures
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Userbuffers test failures
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure gated activations populate FP8 transpose if needed
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci

* Review suggestions from @greptile-apps
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable pre-swizzling with debug quantizer
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @greptile-apps
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflicts and review suggestions
  Update copyright years. Tweak comments. Fix various complaints from @greptile-apps.
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use explicitly sized types in config accessors
  Miscellaneous review suggestions from @ptrendx.
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci

* Make util header for function that computes swizzled scale index
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci

* Apply suggestions from @greptile-apps
  Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
  Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Update expected error message in FP8 block-scaling test
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @yaox12
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent a652730f
@@ -120,6 +120,7 @@ class Quantizer {
   bool rowwise_usage = true;
   bool columnwise_usage = true;
   bool internal = false;
+  bool optimize_for_gemm = false;
   py::handle quantizer;

 protected:
@@ -231,8 +232,6 @@ class Float8BlockQuantizer : public Quantizer {
   bool force_pow_2_scales = false;
   // Amax within quantization tile has a floor of epsilon.
   float amax_epsilon = 0.0;
-  // Whether quantized tensor will be used in an all-gather
-  bool all_gather_usage = false;

 private:
   int block_scaling_dim = 2;
@@ -358,11 +357,12 @@ inline size_t typeToNumBits(transformer_engine::DType t) {
     case transformer_engine::DType::kByte:
     case transformer_engine::DType::kFloat8E4M3:
     case transformer_engine::DType::kFloat8E5M2:
+    case transformer_engine::DType::kFloat8E8M0:
       return 8;
     case transformer_engine::DType::kFloat4E2M1:
       return 4;
     default:
-      NVTE_ERROR("Invalid type");
+      NVTE_ERROR("Invalid type (", static_cast<int>(t), ").");
   }
 }
@@ -386,8 +386,10 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) {
       return at::kFloat8_e4m3fn;
     case transformer_engine::DType::kFloat8E5M2:
       return at::kFloat8_e5m2;
+    case transformer_engine::DType::kFloat8E8M0:
+      return at::kByte;  // e8m0 dtype requires PyTorch 2.7.0+
     default:
-      NVTE_ERROR("Invalid type");
+      NVTE_ERROR("Invalid type (", static_cast<int>(t), ").");
   }
 }
@@ -414,8 +416,7 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
     case torch::kInt64:
       return transformer_engine::DType::kInt64;
     default:
-      std::cout << "Type: " << static_cast<int>(t) << std::endl;
-      NVTE_ERROR("Invalid type");
+      NVTE_ERROR("Invalid type (", static_cast<int>(t), ").");
   }
 }
@@ -477,7 +478,9 @@ void* getDataPtr(at::Tensor tensor, int offset = 0);
 std::vector<size_t> convertShape(const NVTEShape& shape);
-size_t roundup(const size_t value, const size_t multiple);
+size_t roundup(size_t value, size_t multiple);
+size_t ceildiv(size_t numer, size_t denom);
 NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
...
@@ -7,7 +7,12 @@
 #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
 #define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
+#include <map>
 #include <optional>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
 #include "common.h"
@@ -78,11 +83,6 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
     int64_t window_size_right, bool return_max_logit, bool cuda_graph);
-std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
-                                                      const std::vector<size_t> &shape, DType dtype,
-                                                      bool create_hp_tensor_for_cs,
-                                                      std::optional<at::Tensor> data);
 std::vector<py::object> fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
     bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
@@ -475,6 +475,13 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
 void fused_multi_row_unpadding(at::Tensor input, at::Tensor output,
                                std::vector<size_t> input_row_list,
                                std::vector<size_t> unpadded_input_row_list);
+
+/***************************************************************************************************
+ * Scale swizzling for GEMM
+ **************************************************************************************************/
+
+void inplace_swizzle_scale_for_gemm(py::handle &tensor);
+
 /***************************************************************************************************
  * NVSHMEM APIs
  **************************************************************************************************/
...
@@ -327,9 +327,9 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
         (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none());

     // Construct Python tensor
-    tensor_py_list.emplace_back(Float8BlockwiseQTensorClass(
-        rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype,
-        quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY));
+    tensor_py_list.emplace_back(
+        Float8BlockwiseQTensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale,
+                                    fp8_dtype, quantizer_py_list[i], is_2D_scaled));

     // Construct C++ tensor
     tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
@@ -365,6 +365,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
   const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage;
   const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
   const auto fp8_dtype = quantizer_cpp_list[0]->dtype;
+  const bool with_gemm_swizzled_scales = quantizer_cpp_list[0]->optimize_for_gemm;
   constexpr size_t fp8_elem_size = 1;
   constexpr size_t scale_elem_size = 1;
@@ -475,8 +477,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
     // Construct Python tensor
-    tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data,
-                                                 columnwise_scale, fp8_dtype,
-                                                 quantizer_py_list[i]));
+    tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data,
+                                                 columnwise_scale, fp8_dtype, quantizer_py_list[i],
+                                                 with_gemm_swizzled_scales));

     // Construct C++ tensor
     tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
@@ -488,6 +490,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
         columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
         rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{0},
         columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{0}, scaling_mode));
+    tensor_cpp_list.back().set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
   }
   return retval;
@@ -517,6 +520,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
   const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage;
   const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
   const auto fp4_dtype = quantizer_cpp_list[0]->dtype;
+  const bool with_gemm_swizzled_scales = false;  // TODO (tmoon) Enable based on optimize_for_gemm
   constexpr size_t scale_elem_size = 1;

   // Helper function to construct tensor view
@@ -675,9 +679,9 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
     py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none();

     // Construct Python tensor
-    tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data,
-                                                 columnwise_scale, amax_rowwise, amax_columnwise,
-                                                 fp4_dtype, quantizer_py_list[i]));
+    tensor_py_list.emplace_back(NVFP4TensorClass(
+        rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, amax_rowwise,
+        amax_columnwise, fp4_dtype, quantizer_py_list[i], with_gemm_swizzled_scales));

     // Construct C++ tensor
     // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor,
@@ -693,6 +697,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
         columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
         rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{0},
         columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{0}, scaling_mode);
+    tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);

     // Set the amax rowwise and amax columnwise if available
     if (rowwise_usage) {
@@ -703,6 +708,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
       tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32,
                                          std::vector<size_t>{1});
     }
     tensor_cpp_list.emplace_back(std::move(tensor_wrapper));
   }
 }
...
@@ -240,9 +240,12 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   auto main_stream = at::cuda::getCurrentCUDAStream();

   if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
     // Optionally swizzle the scaling factors
-    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
-    swizzled_scale_inverses_list.emplace_back(
-        std::move(swizzle_scaling_factors(B_tensor, !transb)));
+    auto [A_row_scales, A_col_scales] = swizzle_scales_for_gemm(A_tensor, transa, !transa);
+    auto [B_row_scales, B_col_scales] = swizzle_scales_for_gemm(B_tensor, !transb, transb);
+    swizzled_scale_inverses_list.emplace_back(std::move(A_row_scales));
+    swizzled_scale_inverses_list.emplace_back(std::move(A_col_scales));
+    swizzled_scale_inverses_list.emplace_back(std::move(B_row_scales));
+    swizzled_scale_inverses_list.emplace_back(std::move(B_col_scales));

     // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer
     // as it is not natively supported by cublasLt
@@ -501,9 +504,9 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     // Optionally swizzle the scaling factors
-    swizzled_scale_inverses_list.emplace_back(
-        multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa));
-    swizzled_scale_inverses_list.emplace_back(
-        multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb));
+    swizzled_scale_inverses_list.emplace_back(
+        multi_tensor_swizzle_scales_for_gemm(te_A_wrappers, transa, !transa));
+    swizzled_scale_inverses_list.emplace_back(
+        multi_tensor_swizzle_scales_for_gemm(te_B_wrappers, !transb, transb));

     // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer
     // as it is not natively supported by cublasLt
...
@@ -89,14 +89,8 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   TensorWrapper mu_nvte = makeTransformerEngineTensor(mu_py);
   TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);

-  // Output tensor
+  // Quantizer
   auto quantizer_cpp = convert_quantizer(quantizer);
-  TensorWrapper out_nvte;
-  if (out.is_none()) {
-    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
-  } else {
-    out_nvte = makeTransformerEngineTensor(out, quantizer);
-  }

   // Choose implementation
   enum class Impl {
@@ -135,6 +129,19 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
     }
   }

+  // Output tensor
+  TensorWrapper out_nvte;
+  if (out.is_none()) {
+    if (impl == Impl::FULLY_FUSED) {
+      // FP8 has no special logic to optimize for GEMM, MXFP8 cuDNN
+      // kernel does not support GEMM swizzled scales
+      quantizer_cpp->optimize_for_gemm = false;
+    }
+    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
+  } else {
+    out_nvte = makeTransformerEngineTensor(out, quantizer);
+  }
+
   // Construct unquantized output tensor if needed
   TensorWrapper unquantized_out_nvte;
   py::object unquantized_out;
@@ -318,14 +325,8 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   at::Tensor rsigma_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
   TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);

-  // Output tensor
+  // Quantizer
   auto quantizer_cpp = convert_quantizer(quantizer);
-  TensorWrapper out_nvte;
-  if (out.is_none()) {
-    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
-  } else {
-    out_nvte = makeTransformerEngineTensor(out, quantizer);
-  }

   // Choose implementation
   enum class Impl {
@@ -364,6 +365,19 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
     }
   }

+  // Output tensor
+  TensorWrapper out_nvte;
+  if (out.is_none()) {
+    if (impl == Impl::FULLY_FUSED) {
+      // FP8 has no special logic to optimize for GEMM, MXFP8 cuDNN
+      // kernel does not support GEMM swizzled scales
+      quantizer_cpp->optimize_for_gemm = false;
+    }
+    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
+  } else {
+    out_nvte = makeTransformerEngineTensor(out, quantizer);
+  }
+
   // Construct unquantized output tensor if needed
   TensorWrapper unquantized_out_nvte;
   py::object unquantized_out;
...
@@ -290,6 +290,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());
   m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding,
         "Fused Multi-tensor unpadding", py::call_guard<py::gil_scoped_release>());
+  m.def("swizzle_scales_for_gemm_", &transformer_engine::pytorch::inplace_swizzle_scale_for_gemm,
+        "Convert tensor block scales into GEMM swizzled format");

   // attention kernels
   m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
...
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <optional>
#include <tuple>
#include <utility>
#include <vector>

#include "common.h"
#include "common/common.h"
#include "extensions.h"
#include "pybind.h"
#include "util.h"

namespace transformer_engine {
namespace pytorch {

namespace {
void reset_tensor_data(transformer_engine::TensorWrapper &tensor, bool rowwise, bool columnwise) {
  NVTEShape shape;
  shape.ndim = 1;
  shape.data[0] = 0;
  const transformer_engine::DType dtype = transformer_engine::DType::kFloat32;
  if (rowwise) {
    tensor.set_rowwise_data(nullptr, dtype, shape);
    tensor.set_rowwise_scale_inv(nullptr, dtype, shape);
  }
  if (columnwise) {
    tensor.set_columnwise_data(nullptr, dtype, shape);
    tensor.set_columnwise_scale_inv(nullptr, dtype, shape);
  }
}

}  // namespace
std::tuple<std::optional<at::Tensor>, std::optional<at::Tensor>> swizzle_scales_for_gemm(
    transformer_engine::TensorWrapper &tensor, bool rowwise_usage, bool columnwise_usage) {
  // Return early if scale swizzling is not required
  const auto scaling_mode = tensor.scaling_mode();
  switch (scaling_mode) {
    case NVTE_MXFP8_1D_SCALING:
    case NVTE_NVFP4_1D_SCALING:
      // Tensor format requires scale swizzling
      break;
    case NVTE_INVALID_SCALING:
      NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
    default:
      // Tensor format does not require scale swizzling for GEMM
      return {std::nullopt, std::nullopt};
  }

  // Return early if scales are already swizzled
  if (tensor.get_with_gemm_swizzled_scales()) {
    return {std::nullopt, std::nullopt};
  }

  // CUDA stream
  auto stream = at::cuda::getCurrentCUDAStream();

  // Swizzle row-wise scales if needed
  std::optional<at::Tensor> rowwise_scales_pyt;
  if (rowwise_usage) {
    // Buffer for unswizzled scales
    const auto input_scales_nvte = tensor.get_rowwise_scale_inv();
    void *input_scales_dptr = input_scales_nvte.data_ptr;
    const NVTEShape input_scales_shape = input_scales_nvte.shape;
    const auto scales_dtype = static_cast<DType>(input_scales_nvte.dtype);

    // Allocate buffer for swizzled scales
    const NVTEShape output_scales_shape = input_scales_shape;
    rowwise_scales_pyt = allocateSpace(input_scales_shape, scales_dtype, false);
    void *output_scales_dptr = getDataPtr(*rowwise_scales_pyt);

    // Initialize TE tensors with scales
    const auto data_nvte = tensor.get_rowwise_data();
    const auto data_dtype = static_cast<DType>(data_nvte.dtype);
    TensorWrapper input_nvte(scaling_mode);
    input_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape);
    input_nvte.set_rowwise_scale_inv(input_scales_dptr, scales_dtype, input_scales_shape);
    TensorWrapper output_nvte(scaling_mode);
    output_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape);
    output_nvte.set_rowwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape);
    output_nvte.set_with_gemm_swizzled_scales(true);

    // Launch kernel
    NVTE_SCOPED_GIL_RELEASE(
        { nvte_swizzle_scaling_factors(input_nvte.data(), output_nvte.data(), stream); });

    // Update tensor with swizzled scales
    tensor.set_rowwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape);
  }

  // Swizzle column-wise scales if needed
  std::optional<at::Tensor> columnwise_scales_pyt;
  if (columnwise_usage) {
    // Buffer for unswizzled scales
    const auto input_scales_nvte = tensor.get_columnwise_scale_inv();
    void *input_scales_dptr = input_scales_nvte.data_ptr;
    const NVTEShape input_scales_shape = input_scales_nvte.shape;
    const auto scales_dtype = static_cast<DType>(input_scales_nvte.dtype);

    // Allocate buffer for swizzled scales
    const NVTEShape output_scales_shape = input_scales_shape;
    columnwise_scales_pyt = allocateSpace(input_scales_shape, scales_dtype, false);
    void *output_scales_dptr = getDataPtr(*columnwise_scales_pyt);

    // Initialize TE tensors with scales
    const auto data_nvte = tensor.get_columnwise_data();
    const auto data_dtype = static_cast<DType>(data_nvte.dtype);
    TensorWrapper input_nvte(scaling_mode);
    input_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape);
    input_nvte.set_columnwise_scale_inv(input_scales_dptr, scales_dtype, input_scales_shape);
    TensorWrapper output_nvte(scaling_mode);
    output_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape);
    output_nvte.set_columnwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape);
    output_nvte.set_with_gemm_swizzled_scales(true);

    // Launch kernel
    NVTE_SCOPED_GIL_RELEASE(
        { nvte_swizzle_scaling_factors(input_nvte.data(), output_nvte.data(), stream); });

    // Update tensor with swizzled scales
    tensor.set_columnwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape);
  }

  // Update tensor
  reset_tensor_data(tensor, !rowwise_usage, !columnwise_usage);
  tensor.set_with_gemm_swizzled_scales(true);
  return {std::move(rowwise_scales_pyt), std::move(columnwise_scales_pyt)};
}
std::optional<at::Tensor> multi_tensor_swizzle_scales_for_gemm(
    std::vector<transformer_engine::TensorWrapper> &tensors, bool rowwise_usage,
    bool columnwise_usage) {
  // Checks and trivial cases
  NVTE_CHECK(rowwise_usage != columnwise_usage,
             "Expect exactly one of rowwise_usage=", rowwise_usage,
             " and columnwise_usage=", columnwise_usage, ".");
  if (tensors.empty()) {
    return std::nullopt;
  }
  const auto scaling_mode = tensors.front().scaling_mode();
  for (const auto &tensor : tensors) {
    NVTE_CHECK(tensor.scaling_mode() == scaling_mode, "Tensors have different scaling modes");
  }

  // Return early if scale swizzling is not required
  switch (scaling_mode) {
    case NVTE_MXFP8_1D_SCALING:
    case NVTE_NVFP4_1D_SCALING:
      // Tensor format requires scale swizzling
      break;
    case NVTE_INVALID_SCALING:
      NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
    default:
      // Tensor format does not require scale swizzling for GEMM
      return std::nullopt;
  }

  // Filter out tensors that already have swizzled scales
  std::vector<TensorWrapper *> tensors_needing_swizzle;
  for (auto &tensor : tensors) {
    if (!tensor.get_with_gemm_swizzled_scales()) {
      tensors_needing_swizzle.push_back(&tensor);
    }
  }
  if (tensors_needing_swizzle.empty()) {
    return std::nullopt;
  }

  // Determine buffer size needed for swizzled scales
  std::vector<size_t> output_scales_offsets;
  size_t output_scales_bytes = 0;
  for (auto &tensor : tensors_needing_swizzle) {
    const auto scales_nvte =
        (rowwise_usage ? tensor->get_rowwise_scale_inv() : tensor->get_columnwise_scale_inv());
    const auto &shape = scales_nvte.shape;
    const auto dtype = static_cast<DType>(scales_nvte.dtype);
    const auto dtype_bits = transformer_engine::pytorch::typeToNumBits(dtype);
    const auto size = product(shape, 0, shape.ndim);
    output_scales_bytes = roundup(output_scales_bytes, 16);  // align to 16B
    output_scales_offsets.push_back(output_scales_bytes);
    output_scales_bytes += ceildiv(size * dtype_bits, 8);
  }

  // Allocate buffer for swizzled scales
  auto output_scales_pyt = allocateSpace(std::vector<size_t>{output_scales_bytes},
                                         transformer_engine::DType::kByte, false);
  uint8_t *output_scales_dptr = reinterpret_cast<uint8_t *>(getDataPtr(output_scales_pyt));

  // Construct TE tensors with only scales
  std::vector<transformer_engine::TensorWrapper> inputs_nvte, outputs_nvte;
  for (size_t i = 0; i < tensors_needing_swizzle.size(); ++i) {
    auto &tensor = *tensors_needing_swizzle[i];
    inputs_nvte.emplace_back(scaling_mode);
    outputs_nvte.emplace_back(scaling_mode);
    auto &input_nvte = inputs_nvte.back();
    auto &output_nvte = outputs_nvte.back();
    output_nvte.set_with_gemm_swizzled_scales(true);
    if (rowwise_usage) {
      const auto data_nvte = tensor.get_rowwise_data();
      const auto scales_nvte = tensor.get_rowwise_scale_inv();
      const auto data_dtype = static_cast<transformer_engine::DType>(data_nvte.dtype);
      const auto scales_dtype = static_cast<transformer_engine::DType>(scales_nvte.dtype);
      input_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape);
      input_nvte.set_rowwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape);
      output_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape);
      output_nvte.set_rowwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype,
                                        scales_nvte.shape);
    } else {
      const auto data_nvte = tensor.get_columnwise_data();
      const auto scales_nvte = tensor.get_columnwise_scale_inv();
      const auto data_dtype = static_cast<transformer_engine::DType>(data_nvte.dtype);
      const auto scales_dtype = static_cast<transformer_engine::DType>(scales_nvte.dtype);
      input_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape);
      input_nvte.set_columnwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape);
      output_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape);
      output_nvte.set_columnwise_scale_inv(output_scales_dptr + output_scales_offsets[i],
                                           scales_dtype, scales_nvte.shape);
    }
  }

  // Pack raw NVTETensors into vectors
  std::vector<NVTETensor> inputs_nvte_raw, outputs_nvte_raw;
  for (auto &tensor : inputs_nvte) {
    inputs_nvte_raw.emplace_back(tensor.data());
  }
  for (auto &tensor : outputs_nvte) {
    outputs_nvte_raw.emplace_back(tensor.data());
  }

  // Launch kernel
  NVTE_SCOPED_GIL_RELEASE({
    nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte_raw.data(), outputs_nvte_raw.data(),
                                              inputs_nvte_raw.size(),
                                              at::cuda::getCurrentCUDAStream());
  });

  // Update tensors with swizzled scales
for (size_t i = 0; i < tensors_needing_swizzle.size(); ++i) {
auto &tensor = *tensors_needing_swizzle[i];
reset_tensor_data(tensor, !rowwise_usage, !columnwise_usage);
tensor.set_with_gemm_swizzled_scales(true);
if (rowwise_usage) {
auto scales_nvte = outputs_nvte[i].get_rowwise_scale_inv();
const auto scales_dtype = static_cast<transformer_engine::DType>(scales_nvte.dtype);
tensor.set_rowwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype,
scales_nvte.shape);
} else {
auto scales_nvte = outputs_nvte[i].get_columnwise_scale_inv();
const auto scales_dtype = static_cast<transformer_engine::DType>(scales_nvte.dtype);
tensor.set_columnwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype,
scales_nvte.shape);
}
}
return std::move(output_scales_pyt);
}
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input,
bool rowwise) {
// Check input tensor
const NVTEScalingMode scaling_mode = input.scaling_mode();
NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
"Input tensor must be a block scaling tensor");
// Get tensor data
NVTEBasicTensor data;
size_t data_flat_first_dim = 1;
size_t data_flat_last_dim = 1;
if (rowwise) {
data = input.get_rowwise_data();
for (size_t i = 0; i < data.shape.ndim - 1; ++i) {
data_flat_first_dim *= data.shape.data[i];
}
data_flat_last_dim = data.shape.data[data.shape.ndim - 1];
} else {
data = input.get_columnwise_data();
data_flat_first_dim = data.shape.data[0];
for (size_t i = 1; i < data.shape.ndim; ++i) {
data_flat_last_dim *= data.shape.data[i];
}
}
NVTEShape data_shape{};
data_shape.data[0] = data_flat_first_dim;
data_shape.data[1] = data_flat_last_dim;
data_shape.ndim = 2;
// Recreate input tensor with rowwise usage
transformer_engine::TensorWrapper input_cu(scaling_mode);
input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
const NVTEBasicTensor scale_inv =
rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv();
input_cu.set_rowwise_scale_inv(
scale_inv.data_ptr, static_cast<transformer_engine::DType>(scale_inv.dtype), scale_inv.shape);
// Create output tensor
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
// Swizzled MXFP8 scaling factor dimensions
const size_t swizzled_scale_inv_first_dim = ceildiv(data_flat_first_dim, 128) * 128;
const size_t swizzled_scale_inv_last_dim = ceildiv(data_flat_last_dim, 128) * 4;
// Allocate memory for swizzled mxfp8 scaling factors
at::Tensor swizzled_scale_inv =
allocateSpace(std::vector<size_t>{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim},
transformer_engine::DType::kByte, false);
// Set rowwise scaling factors on output
void *const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
NVTEShape swizzled_scale_inv_shape{};
swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim;
swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim;
swizzled_scale_inv_shape.ndim = 2;
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
swizzled_scale_inv_shape);
output_cu.set_with_gemm_swizzled_scales(true);
// Convert scaling factors from the FP8 block scaling GEMM_READY format to the MXFP8 swizzled format
NVTE_SCOPED_GIL_RELEASE({
nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
});
// Replace the input tensor with the converted MXFP8 tensor and return the swizzled scaling
// factors so they stay alive for the duration of the GEMM
input = std::move(output_cu);
return swizzled_scale_inv;
}
void inplace_swizzle_scale_for_gemm(py::handle &tensor) {
// Convert Python tensor to C++ tensor
auto tensor_nvte = makeTransformerEngineTensor(tensor, py::none());
// Return early if scale swizzling is not required
const auto scaling_mode = tensor_nvte.scaling_mode();
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING:
// Tensor format requires scale swizzling
break;
case NVTE_INVALID_SCALING:
NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
default:
// Tensor format does not require scale swizzling for GEMM
return;
}
// Return early if scales are already swizzled
if (tensor_nvte.get_with_gemm_swizzled_scales()) {
return;
}
// Check what scaling factors the tensor contains
auto is_empty = [](const NVTEBasicTensor &t) -> bool {
return t.shape.ndim == 1 && t.shape.data[0] == 0;
};
const bool has_rowwise_scales = !is_empty(tensor_nvte.get_rowwise_scale_inv());
const bool has_columnwise_scales = !is_empty(tensor_nvte.get_columnwise_scale_inv());
// Swizzle scaling factors
auto [rowwise_scales, columnwise_scales] =
swizzle_scales_for_gemm(tensor_nvte, has_rowwise_scales, has_columnwise_scales);
// Update Python tensor with swizzled scales
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING:
if (has_rowwise_scales) {
tensor.attr("_rowwise_scale_inv") = rowwise_scales;
}
if (has_columnwise_scales) {
tensor.attr("_columnwise_scale_inv") = columnwise_scales;
}
tensor.attr("_with_gemm_swizzled_scales") = true;
break;
default:
NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
}
}
} // namespace pytorch
} // namespace transformer_engine
...@@ -52,10 +52,12 @@ Quantizer::Quantizer(const py::handle& quantizer) { ...@@ -52,10 +52,12 @@ Quantizer::Quantizer(const py::handle& quantizer) {
this->rowwise_usage = true; this->rowwise_usage = true;
this->columnwise_usage = true; this->columnwise_usage = true;
this->internal = false; this->internal = false;
this->optimize_for_gemm = false;
} else { } else {
this->rowwise_usage = quantizer.attr("rowwise_usage").cast<bool>(); this->rowwise_usage = quantizer.attr("rowwise_usage").cast<bool>();
this->columnwise_usage = quantizer.attr("columnwise_usage").cast<bool>(); this->columnwise_usage = quantizer.attr("columnwise_usage").cast<bool>();
this->internal = quantizer.attr("internal").cast<bool>(); this->internal = quantizer.attr("internal").cast<bool>();
this->optimize_for_gemm = quantizer.attr("optimize_for_gemm").cast<bool>();
this->quantizer = quantizer; this->quantizer = quantizer;
} }
} }
...@@ -555,7 +557,6 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti ...@@ -555,7 +557,6 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>(); this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2, NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
"Unsupported block scaling dim."); "Unsupported block scaling dim.");
this->all_gather_usage = quantizer.attr("all_gather_usage").cast<bool>();
} }
void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {}
...@@ -575,10 +576,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -575,10 +576,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
Float8BlockScaleTensorFormat data_format =
(all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
: Float8BlockScaleTensorFormat::GEMM_READY);
if (rowwise_usage) { if (rowwise_usage) {
data_rowwise = at::empty(torch_shape, opts); data_rowwise = at::empty(torch_shape, opts);
auto scale_shape = get_scale_shape(shape, false); auto scale_shape = get_scale_shape(shape, false);
...@@ -597,21 +594,13 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -597,21 +594,13 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ", NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ",
columnwise_shape, " torch shape: ", torch_columnwise_shape); columnwise_shape, " torch shape: ", torch_columnwise_shape);
if (torch_shape.size() > 0) { if (torch_shape.size() > 0) {
if (!all_gather_usage) { torch_columnwise_shape.reserve(torch_shape.size());
torch_columnwise_shape.reserve(torch_shape.size()); columnwise_shape.reserve(shape.size());
columnwise_shape.reserve(shape.size()); torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); columnwise_shape.push_back(shape[shape.size() - 1]);
columnwise_shape.push_back(shape[shape.size() - 1]); for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
for (size_t i = 0; i < torch_shape.size() - 1; ++i) { torch_columnwise_shape.push_back(torch_shape[i]);
torch_columnwise_shape.push_back(torch_shape[i]); columnwise_shape.push_back(shape[i]);
columnwise_shape.push_back(shape[i]);
}
} else {
// assert we are doing 1D scaling
NVTE_CHECK(block_scaling_dim == 1,
"Compact columnwise format is not supported for 128x128 2D block scaling.");
torch_columnwise_shape = torch_shape;
columnwise_shape = shape;
} }
} }
auto scale_shape = get_scale_shape(shape, true); auto scale_shape = get_scale_shape(shape, true);
...@@ -635,7 +624,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -635,7 +624,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
"rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
"rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
"is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format); "is_2D_scaled"_a = (block_scaling_dim == 2));
} else { } else {
py::handle Float8BlockwiseQTensorClass( py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass)); reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
...@@ -643,8 +632,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -643,8 +632,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
"shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
"columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
"columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2), "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2));
"data_format"_a = data_format);
} }
return {std::move(tensor), std::move(ret)}; return {std::move(tensor), std::move(ret)};
...@@ -654,6 +642,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te ...@@ -654,6 +642,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
py::object tensor) const { py::object tensor) const {
const DType dtype = tensor.attr("_fp8_dtype").cast<DType>(); const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>(); bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();
const bool with_gemm_swizzled_scales = true;
// Extract buffers from Python tensor // Extract buffers from Python tensor
auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> { auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
...@@ -675,13 +664,10 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te ...@@ -675,13 +664,10 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector<size_t> { auto get_columnwise_shape = [&columnwise_data]() -> std::vector<size_t> {
if (!columnwise_data) { if (!columnwise_data) {
return std::vector<size_t>(); return std::vector<size_t>();
} }
if (all_gather_usage) {
return getTensorShape(*columnwise_data);
}
std::vector<size_t> shape = getTensorShape(*columnwise_data); std::vector<size_t> shape = getTensorShape(*columnwise_data);
std::vector<size_t> shape_transposed(shape.size()); std::vector<size_t> shape_transposed(shape.size());
for (size_t i = 0; i + 1 < shape.size(); ++i) { for (size_t i = 0; i + 1 < shape.size(); ++i) {
...@@ -696,12 +682,12 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te ...@@ -696,12 +682,12 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
if (rowwise_data) { if (rowwise_data) {
shape = getTensorShape(*rowwise_data); shape = getTensorShape(*rowwise_data);
if (columnwise_data) { if (columnwise_data) {
auto expected_shape = get_columnwise_shape(all_gather_usage); auto expected_shape = get_columnwise_shape();
NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape, NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape,
") and column-wise data (shape=", expected_shape, ") do not match"); ") and column-wise data (shape=", expected_shape, ") do not match");
} }
} else { } else {
shape = get_columnwise_shape(all_gather_usage); shape = get_columnwise_shape();
} }
std::vector<int64_t> torch_shape; std::vector<int64_t> torch_shape;
for (auto s : shape) { for (auto s : shape) {
...@@ -738,21 +724,13 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te ...@@ -738,21 +724,13 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
std::vector<size_t> columnwise_shape; std::vector<size_t> columnwise_shape;
std::vector<int64_t> torch_columnwise_shape; std::vector<int64_t> torch_columnwise_shape;
if (torch_shape.size() > 0) { if (torch_shape.size() > 0) {
if (!all_gather_usage) { torch_columnwise_shape.reserve(torch_shape.size());
torch_columnwise_shape.reserve(torch_shape.size()); columnwise_shape.reserve(shape.size());
columnwise_shape.reserve(shape.size()); torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); columnwise_shape.push_back(shape[shape.size() - 1]);
columnwise_shape.push_back(shape[shape.size() - 1]); for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
for (size_t i = 0; i < torch_shape.size() - 1; ++i) { torch_columnwise_shape.push_back(torch_shape[i]);
torch_columnwise_shape.push_back(torch_shape[i]); columnwise_shape.push_back(shape[i]);
columnwise_shape.push_back(shape[i]);
}
} else {
// assert we are doing 1D scaling
NVTE_CHECK(block_scaling_dim == 1,
"Compact columnwise format is not supported for 128x128 2D block scaling.");
torch_columnwise_shape = torch_shape;
columnwise_shape = shape;
} }
} }
if (!columnwise_data) { if (!columnwise_data) {
...@@ -798,6 +776,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te ...@@ -798,6 +776,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape);
} }
ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
set_quantization_params(&ret); set_quantization_params(&ret);
return {std::move(ret), std::move(tensor)}; return {std::move(ret), std::move(tensor)};
} }
...@@ -813,9 +792,6 @@ void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& o ...@@ -813,9 +792,6 @@ void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& o
} }
quant_config.set_force_pow_2_scales(force_pow_2_scales); quant_config.set_force_pow_2_scales(force_pow_2_scales);
quant_config.set_amax_epsilon(amax_epsilon); quant_config.set_amax_epsilon(amax_epsilon);
if (all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
NVTE_SCOPED_GIL_RELEASE({ NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream());
}); });
...@@ -832,10 +808,6 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size ...@@ -832,10 +808,6 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
size_t m_dim = numel / k_dim; size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128; constexpr size_t kBlockLen = 128;
Float8BlockScaleTensorFormat data_format =
(all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
: Float8BlockScaleTensorFormat::GEMM_READY);
std::vector<size_t> scale_shape; std::vector<size_t> scale_shape;
bool rowwise_usage = !columnwise; bool rowwise_usage = !columnwise;
...@@ -845,26 +817,17 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size ...@@ -845,26 +817,17 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
size_t sinv0 = 0; size_t sinv0 = 0;
size_t sinv1 = 0; size_t sinv1 = 0;
if (block_scaling_dim == 2) { if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now sinv0 = ceildiv(m_dim, kBlockLen);
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY, sinv1 = roundup(ceildiv(k_dim, kBlockLen), 4);
"2D scaling is always GEMM_READY for now.");
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) { } else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
// default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY // default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; sinv0 = ceildiv(k_dim, kBlockLen);
sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4); sinv1 = roundup(m_dim, 4);
// if the rowwise format is compact, the scaling factor is not be transposed
if (rowwise_compact) {
std::swap(sinv0, sinv1);
}
} else { } else {
NVTE_CHECK(false, NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor rowwise." "Unsupported block_scaling_dim in create_tensor rowwise."
"Expected 1 or 2. Got ", "Expected 1 or 2. Got ",
block_scaling_dim); block_scaling_dim);
} }
scale_shape = {sinv0, sinv1}; scale_shape = {sinv0, sinv1};
} else { } else {
...@@ -872,24 +835,16 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size ...@@ -872,24 +835,16 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
size_t sinv0 = 0; size_t sinv0 = 0;
size_t sinv1 = 0; size_t sinv1 = 0;
if (block_scaling_dim == 2) { if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now sinv0 = ceildiv(k_dim, kBlockLen);
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY, sinv1 = roundup(ceildiv(m_dim, kBlockLen), 4);
"2D scaling is always GEMM_READY for now.");
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) { } else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT sinv0 = ceildiv(m_dim, kBlockLen);
bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT; sinv1 = roundup(k_dim, 4);
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
// GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS
// for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1]
// so no need to swap sinv0 and sinv1 here
} else { } else {
NVTE_CHECK(false, NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor columnwise." "Unsupported block_scaling_dim in create_tensor columnwise."
"Expected 1 or 2. Got ", "Expected 1 or 2. Got ",
block_scaling_dim); block_scaling_dim);
} }
scale_shape = {sinv0, sinv1}; scale_shape = {sinv0, sinv1};
} }
...@@ -906,6 +861,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve ...@@ -906,6 +861,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
DType dtype) const { DType dtype) const {
using namespace pybind11::literals; using namespace pybind11::literals;
// Scaling factor format
const bool with_gemm_swizzled_scales = this->optimize_for_gemm;
// Tensor dimensions // Tensor dimensions
const std::vector<int64_t> shape_int64(shape.begin(), shape.end()); const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
size_t flat_first_dim = 1; size_t flat_first_dim = 1;
...@@ -951,19 +909,17 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve ...@@ -951,19 +909,17 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
py::object out_py; py::object out_py;
if (internal) { if (internal) {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass)); py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass));
out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, out_py = MXFP8TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py,
"columnwise_data"_a = columnwise_data_py, columnwise_scale_inv_py, this->dtype, this->quantizer,
"rowwise_scale_inv"_a = rowwise_scale_inv_py, with_gemm_swizzled_scales);
"columnwise_scale_inv"_a = columnwise_scale_inv_py,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
} else { } else {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass)); py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass));
out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), out_py = MXFP8TensorClass(
"rowwise_data"_a = rowwise_data_py, "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
"columnwise_data"_a = columnwise_data_py, "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py, "rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py, "columnwise_scale_inv"_a = columnwise_scale_inv_py, "fp8_dtype"_a = this->dtype,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); "quantizer"_a = this->quantizer, "with_gemm_swizzled_scales"_a = with_gemm_swizzled_scales);
} }
// Construct C++ MXFP8 tensor // Construct C++ MXFP8 tensor
...@@ -978,6 +934,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve ...@@ -978,6 +934,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0,
columnwise_scale_inv_shape); columnwise_scale_inv_shape);
} }
out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
this->set_quantization_params(&out_cpp); this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(out_py)}; return {std::move(out_cpp), std::move(out_py)};
...@@ -987,6 +944,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor( ...@@ -987,6 +944,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
py::object tensor) const { py::object tensor) const {
NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor.");
// Scaling factor format
const bool with_gemm_swizzled_scales = this->optimize_for_gemm;
// Extract buffers from Python tensor // Extract buffers from Python tensor
auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> { auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
auto attr_py = tensor.attr(name); auto attr_py = tensor.attr(name);
...@@ -1070,6 +1030,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor( ...@@ -1070,6 +1030,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
// Coerce other attrs // Coerce other attrs
tensor.attr("_fp8_dtype") = dtype; tensor.attr("_fp8_dtype") = dtype;
tensor.attr("_with_gemm_swizzled_scales") = with_gemm_swizzled_scales;
// Construct C++ MXFP8 tensor // Construct C++ MXFP8 tensor
TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING); TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING);
...@@ -1083,6 +1044,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor( ...@@ -1083,6 +1044,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0,
getTensorShape(*columnwise_scale_inv)); getTensorShape(*columnwise_scale_inv));
} }
out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
this->set_quantization_params(&out_cpp); this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)}; return {std::move(out_cpp), std::move(tensor)};
...@@ -1173,6 +1135,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve ...@@ -1173,6 +1135,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
DType dtype) const { DType dtype) const {
using namespace pybind11::literals; using namespace pybind11::literals;
// Scaling factor format
const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) self->optimize_for_gemm
// Tensor dimensions // Tensor dimensions
const std::vector<int64_t> shape_int64(shape.begin(), shape.end()); const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
size_t flat_first_dim = 1; size_t flat_first_dim = 1;
...@@ -1235,12 +1200,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve ...@@ -1235,12 +1200,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
py::object out_py; py::object out_py;
if (internal) { if (internal) {
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass)); py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass));
out_py = NVFP4TensorClass( out_py = NVFP4TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py,
"rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, columnwise_scale_inv_py, amax_rowwise_py, amax_columnwise_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py, this->dtype, this->quantizer, with_gemm_swizzled_scales);
"columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
"amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer);
} else { } else {
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass)); py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass));
out_py = NVFP4TensorClass( out_py = NVFP4TensorClass(
...@@ -1249,7 +1211,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve ...@@ -1249,7 +1211,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
"rowwise_scale_inv"_a = rowwise_scale_inv_py, "rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
"amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer); "quantizer"_a = this->quantizer, "with_gemm_swizzled_scales"_a = with_gemm_swizzled_scales);
} }
// Construct C++ tensor // Construct C++ tensor
...@@ -1272,6 +1234,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve ...@@ -1272,6 +1234,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
    out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
                                std::vector<size_t>{1});
  }
  out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
  this->set_quantization_params(&out_cpp);
  return {std::move(out_cpp), std::move(out_py)};
@@ -1301,6 +1264,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
    py::object tensor) const {
  NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor.");

  // Scaling factor format
  const bool with_gemm_swizzled_scales = false;  // TODO (tmoon) Enable with optimize_for_gemm

  // Extract buffers from Python tensor
  auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
    auto attr_py = tensor.attr(name);
@@ -1438,6 +1404,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
    out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32,
                                std::vector<size_t>{1});
  }
  out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
  this->set_quantization_params(&out_cpp);
  return {std::move(out_cpp), std::move(tensor)};
@@ -55,8 +55,9 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer

TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) {
  auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING);
  const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
  const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
  const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast<bool>();
  NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for MXFP8 Tensor.");
@@ -78,6 +79,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer)
                               getTensorShape(scale_inv));
  }

  // Scale layout
  ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);

  // Quantizer state
  quantizer->set_quantization_params(&ret);
@@ -93,6 +97,7 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
  auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);

  // Row-wise data
  if (rowwise_usage) {
    const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast<at::Tensor>();
    const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
@@ -102,6 +107,8 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
    const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise);
    ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape);
  }
  // Column-wise data
  if (columnwise_usage) {
    const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast<at::Tensor>();
    const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
@@ -112,7 +119,10 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
    const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
    ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape);
  }

  // Quantizer state
  quantizer->set_quantization_params(&ret);
  return ret;
}
@@ -121,8 +131,9 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
  auto ret = TensorWrapper(NVTE_NVFP4_1D_SCALING);
  const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
  const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
  const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast<bool>();
  NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor.");
@@ -150,6 +161,9 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
                               getTensorShape(amax_columnwise));
  }

  // Scale layout
  ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);

  // Quantizer state
  quantizer->set_quantization_params(&ret);
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "util.h"

#include "common.h"
#include "common/common.h"
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
                                                  bool rowwise) {
  using namespace transformer_engine::pytorch;
  if (input.scaling_mode() == NVTE_INVALID_SCALING) {
    NVTE_ERROR("Invalid scaling mode for swizzle.");
  } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING &&
             input.scaling_mode() != NVTE_NVFP4_1D_SCALING) {
    return std::nullopt;
  }
  NVTE_CHECK(input.element_size_bits() == 4 || input.element_size_bits() == 8,
             "4-bit or 8-bit input required for swizzling scaling factors.");
  const auto nvfp4 = input.scaling_mode() == NVTE_NVFP4_1D_SCALING;

  NVTEBasicTensor scale_inv;
  NVTEShape nvte_input_shape;
  if (rowwise) {
    nvte_input_shape = input.shape();
    scale_inv = input.get_rowwise_scale_inv();
  } else {
    nvte_input_shape = input.get_columnwise_data().shape;
    scale_inv = input.get_columnwise_scale_inv();
  }
  auto input_shape = nvte_shape_to_vector(nvte_input_shape);
  auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
  NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape.");

  // Allocate memory for swizzled output.
  auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
  std::vector<int64_t> scale_inv_shape_int;
  for (size_t i = 0; i < scale_inv_shape.size(); ++i) {
    scale_inv_shape_int.push_back(static_cast<int64_t>(scale_inv_shape[i]));
  }
  auto swizzled_scale_inv = at::empty(scale_inv_shape_int, options);
  void* scale_inv_dptr = scale_inv.data_ptr;
  void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);

  transformer_engine::TensorWrapper input_cu(input.scaling_mode());
  transformer_engine::TensorWrapper output_cu(input.scaling_mode());
  const auto input_dtype =
      (nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3;
  const auto scale_inv_dtype =
      (nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0;
  if (rowwise) {
    input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
    input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
    output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
    output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
  } else {
    input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
    input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
    output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
    output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
  }

  // Launch kernel
  nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());

  if (rowwise) {
    input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
  } else {
    input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
  }
  return swizzled_scale_inv;
}
std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
    std::vector<transformer_engine::TensorWrapper>& tensors, bool rowwise) {
  using namespace transformer_engine::pytorch;
  if (tensors.empty()) {
    return std::nullopt;
  }
  bool all_same_scaling_mode = std::all_of(
      tensors.cbegin(), tensors.cend(), [&tensors](const transformer_engine::TensorWrapper& val) {
        return val.scaling_mode() == tensors.front().scaling_mode();
      });
  NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same.");
  if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) {
    NVTE_ERROR("Invalid scaling mode for swizzle.");
  } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING &&
             tensors.front().scaling_mode() != NVTE_NVFP4_1D_SCALING) {
    return std::nullopt;
  }
  const auto scaling_mode = tensors.front().scaling_mode();
  const auto nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;

  std::vector<transformer_engine::TensorWrapper> wrappers;
  std::vector<NVTETensor> input_tensors, output_tensors;

  // Collect scale_inv shapes and calculate buffer size and offsets for scale_invs
  std::vector<std::vector<size_t>> scale_inv_shapes;
  std::vector<void*> scale_inv_dptrs;
  size_t buffer_size = 0;
  std::vector<size_t> scale_inv_offsets;
  constexpr size_t scale_elem_size = 1;
  for (auto& tensor : tensors) {
    NVTEBasicTensor scale_inv;
    if (rowwise) {
      scale_inv = tensor.get_rowwise_scale_inv();
    } else {
      scale_inv = tensor.get_columnwise_scale_inv();
    }
    auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
    buffer_size = roundup(buffer_size, 16);  // align to 16B
    scale_inv_offsets.push_back(buffer_size);
    buffer_size += product(scale_inv_shape) * scale_elem_size;
    scale_inv_shapes.emplace_back(scale_inv_shape);
    scale_inv_dptrs.push_back(scale_inv.data_ptr);
  }

  // Allocate full buffer
  auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8));

  const auto input_dtype =
      (nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3;
  const auto scale_inv_dtype =
      (nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0;
  for (size_t i = 0; i < tensors.size(); ++i) {
    auto& tensor = tensors[i];
    void* scale_inv_dptr = scale_inv_dptrs[i];
    void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);

    // Empty tensors don't require scale swizzling
    if (tensor.numel() == 0) {
      continue;
    }

    // Tensor shape
    NVTEShape nvte_input_shape;
    if (rowwise) {
      nvte_input_shape = tensor.shape();
    } else {
      nvte_input_shape = tensor.get_columnwise_data().shape;
    }
    auto input_shape = nvte_shape_to_vector(nvte_input_shape);

    // Reconstruct input only to avoid swizzling both directions if not needed.
    // Use any 8 bit type, it's irrelevant.
    transformer_engine::TensorWrapper input_cu(scaling_mode);
    transformer_engine::TensorWrapper output_cu(scaling_mode);
    if (rowwise) {
      input_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
      input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
      output_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
      output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
                                      scale_inv_shapes[i]);
      // Set the swizzled scaling factor to the original tensor.
      tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
    } else {
      input_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
      input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
      output_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
      output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
                                         scale_inv_shapes[i]);
      // Set the swizzled scaling factor to the original tensor.
      tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
                                      scale_inv_shapes[i]);
    }
    input_tensors.emplace_back(input_cu.data());
    output_tensors.emplace_back(output_cu.data());
    wrappers.emplace_back(std::move(input_cu));
    wrappers.emplace_back(std::move(output_cu));
  }

  // Launch kernel
  nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(),
                                            input_tensors.size(), at::cuda::getCurrentCUDAStream());
  return buffer;
}
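The loop above packs every tensor's swizzled scales into one byte buffer, aligning each tensor's offset to 16 bytes before appending its scale bytes. A minimal Python sketch of that offset computation, assuming the same one-byte-per-scale-element convention as the C++ loop (`roundup` and `pack_offsets` are illustrative names, not part of the codebase):

```python
def roundup(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple."""
    return ((value + multiple - 1) // multiple) * multiple


def pack_offsets(scale_counts):
    """Compute 16-byte-aligned offsets and the total buffer size
    for a list of per-tensor scale element counts."""
    offsets, total = [], 0
    for count in scale_counts:
        total = roundup(total, 16)  # align each tensor's scales to 16 B
        offsets.append(total)
        total += count              # scale_elem_size == 1 byte
    return offsets, total


print(pack_offsets([100, 40, 7]))  # ([0, 112, 160], 167)
```

The 16-byte alignment keeps each tensor's scale region suitably aligned for vectorized kernel accesses while wasting at most 15 bytes per tensor.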
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input,
                                                 bool rowwise) {
  using namespace transformer_engine::pytorch;
  using transformer_engine::DIVUP;

  // Check input tensor
  const NVTEScalingMode scaling_mode = input.scaling_mode();
  NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
             "Input tensor must be a block scaling tensor");

  // Get tensor data
  NVTEBasicTensor data;
  size_t data_flat_first_dim = 1;
  size_t data_flat_last_dim = 1;
  if (rowwise) {
    data = input.get_rowwise_data();
    for (size_t i = 0; i < data.shape.ndim - 1; ++i) {
      data_flat_first_dim *= data.shape.data[i];
    }
    data_flat_last_dim = data.shape.data[data.shape.ndim - 1];
  } else {
    data = input.get_columnwise_data();
    data_flat_first_dim = data.shape.data[0];
    for (size_t i = 1; i < data.shape.ndim; ++i) {
      data_flat_last_dim *= data.shape.data[i];
    }
  }
  NVTEShape data_shape{};
  data_shape.data[0] = data_flat_first_dim;
  data_shape.data[1] = data_flat_last_dim;
  data_shape.ndim = 2;

  // Recreate input tensor with rowwise usage
  transformer_engine::TensorWrapper input_cu(scaling_mode);
  input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
  const NVTEBasicTensor scale_inv =
      rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv();
  input_cu.set_rowwise_scale_inv(
      scale_inv.data_ptr, static_cast<transformer_engine::DType>(scale_inv.dtype), scale_inv.shape);

  // Create output tensor
  transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
  output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);

  // Output swizzled mxfp8 scaling factor dimensions
  const size_t swizzled_scale_inv_first_dim = DIVUP<size_t>(data_flat_first_dim, 128) * 128;
  const size_t swizzled_scale_inv_last_dim = DIVUP<size_t>(data_flat_last_dim, 128) * 4;

  // Allocate memory for swizzled mxfp8 scaling factors
  const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
  at::Tensor swizzled_scale_inv = at::empty(
      std::vector<int64_t>{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim}, options);

  // Set rowwise scaling factors on output
  void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
  NVTEShape swizzled_scale_inv_shape{};
  swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim;
  swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim;
  swizzled_scale_inv_shape.ndim = 2;
  output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
                                  swizzled_scale_inv_shape);

  // Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format
  nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(),
                                                      at::cuda::getCurrentCUDAStream());

  // Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling factor
  // for it to be kept alive during the GEMM
  input = std::move(output_cu);
  return swizzled_scale_inv;
}
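The padded dimensions computed above fix the swizzled MXFP8 scale layout: rows are padded to a multiple of 128, and every 128 data columns contribute 4 scale bytes (one E8M0 byte per 32-element block). A small Python sketch of the shape computation, assuming the same ceiling-division convention as `DIVUP` (`divup` and `swizzled_mxfp8_scale_shape` are illustrative names):

```python
def divup(a: int, b: int) -> int:
    """Ceiling division, matching DIVUP."""
    return -(-a // b)


def swizzled_mxfp8_scale_shape(rows: int, cols: int) -> tuple:
    """Shape of the swizzled MXFP8 scale-inverse buffer for a
    (rows, cols) data tensor: rows padded to 128, and 4 scale
    bytes per 128 columns (128 elements / 32-element blocks)."""
    return (divup(rows, 128) * 128, divup(cols, 128) * 4)


print(swizzled_mxfp8_scale_shape(256, 1024))  # (256, 32)
```

The row padding means a small tensor still allocates a full 128-row tile of scales, which is what the GEMM's swizzled-scale addressing expects.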
@@ -10,33 +10,44 @@

#include <torch/extension.h>

#include <optional>
#include <tuple>
#include <vector>

#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
namespace pytorch {

/*! \brief Convert tensor block scales into GEMM swizzled format.
 *
 * The returned swizzled scales should be kept alive during the GEMM.
 */
std::tuple<std::optional<at::Tensor>, std::optional<at::Tensor>> swizzle_scales_for_gemm(
    TensorWrapper& tensor, bool rowwise_usage, bool columnwise_usage);
/*! \brief Convert multiple tensor block scales into GEMM swizzled format.
 *
 * The returned swizzled scales should be kept alive during the GEMMs.
 */
std::optional<at::Tensor> multi_tensor_swizzle_scales_for_gemm(std::vector<TensorWrapper>& tensors,
                                                               bool rowwise_usage,
                                                               bool columnwise_usage);
/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place.
 *
 * If rowwise==false, the columnwise data will be reinterpreted as
 * rowwise data to avoid transposing it in memory. Due to differences
 * in how block scaling and mxfp8 store data, this requires the
 * calling code to treat the output tensor as having been transposed
 * in this case.
 *
 * Returns the swizzled scaling factor of the converted mxfp8 tensor.
 * The returned swizzled scaling factor tensor should be kept alive
 * during the GEMM.
 */
at::Tensor convert_block_scaling_to_mxfp8_tensor(TensorWrapper& input, bool rowwise);

}  // namespace pytorch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
@@ -48,7 +48,7 @@ from .tensor.storage.float8_tensor_storage import Float8TensorStorage
from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor

__all__ = ["checkpoint", "CudaRNGStatesTracker"]
@@ -930,6 +930,34 @@ def reduce_scatter_along_first_dim(
    return output, handle
@dataclass
class _AsyncHandle:
    """Handle for asynchronous collectives."""

    async_handle: torch.distributed.Work
    post_process_function: Optional[Callable] = None
    post_process_function_args: Optional[Tuple[Any, ...]] = None
    post_process_function_kwargs: Optional[Dict[str, Any]] = None
    _synchronized: bool = False

    def wait(self) -> None:
        """Synchronize the asynchronous communication.

        Perform post-processing if needed.
        """
        if self._synchronized:
            return
        self.async_handle.wait()
        if self.post_process_function is not None:
            args = self.post_process_function_args
            args = () if args is None else args
            kwargs = self.post_process_function_kwargs
            kwargs = {} if kwargs is None else kwargs
            self.post_process_function(*args, **kwargs)
        self._synchronized = True
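The handle above makes `wait()` idempotent and defers post-processing (e.g. de-interleaving gathered buffers) until the caller synchronizes. A self-contained sketch of the same pattern, with `AsyncHandleSketch` mirroring the `_AsyncHandle` semantics and `FakeWork` standing in for `torch.distributed.Work` since no real collective is launched here:

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple


@dataclass
class AsyncHandleSketch:
    """Wait on the collective once, run the post-process callback once,
    then become a no-op on repeated wait() calls."""

    async_handle: Any  # stands in for torch.distributed.Work
    post_process_function: Optional[Callable] = None
    post_process_function_args: Optional[Tuple[Any, ...]] = None
    post_process_function_kwargs: Optional[Dict[str, Any]] = None
    _synchronized: bool = False

    def wait(self) -> None:
        if self._synchronized:
            return
        self.async_handle.wait()
        if self.post_process_function is not None:
            args = self.post_process_function_args or ()
            kwargs = self.post_process_function_kwargs or {}
            self.post_process_function(*args, **kwargs)
        self._synchronized = True


class FakeWork:
    """Stand-in for an async collective's work handle."""

    def __init__(self) -> None:
        self.waits = 0

    def wait(self) -> None:
        self.waits += 1


events = []
handle = AsyncHandleSketch(
    FakeWork(),
    post_process_function=events.append,
    post_process_function_args=("deinterleave",),
)
handle.wait()
handle.wait()  # second call is a no-op: one wait, one post-process
```

Folding the post-processing into the handle lets callers treat quantized and plain all-gathers uniformly: both return something with a `wait()` method.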
def _all_gather_fp8(
    inp: torch.Tensor,
    process_group: dist_group_type,
@@ -1020,73 +1048,7 @@ def _all_gather_fp8(
    return out, handle

def _start_all_gather_fp8_blockwise(
    inp: torch.Tensor,
    process_group: dist_group_type,
    *,
@@ -1125,44 +1087,25 @@ def _all_gather_fp8_blockwise(
    )
    world_size = get_distributed_world_size(process_group)

    # Output tensor dims
    if out_shape is None:
        out_shape = list(inp.size())
        out_shape[0] *= world_size
    # Check that quantizer is valid
    if quantizer is None:
        raise ValueError("Quantizer is missing")
    if not isinstance(quantizer, Float8BlockQuantizer):
        raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")

    # Fall back to high-precision all-gather if FP8 is not supported
    if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
        out = torch.empty(out_shape, dtype=dtype, device=device)
        torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
        out = quantizer(out)
        return out, None
    # Quantize input tensor if needed
    if not isinstance(inp, Float8BlockwiseQTensorStorage):
        inp = quantizer(inp)
    elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
@@ -1177,14 +1120,9 @@ def _all_gather_fp8_blockwise(
    # Construct Float8BlockwiseQTensor output tensor
    out = quantizer.make_empty(out_shape, dtype=dtype, device=device)

    # Temporary buffers for all-gathering transposed buffers
    interleaved_rowwise_scale_inv = None
    interleaved_columnwise_data = None
    # Coalesce NCCL collectives
    with torch.distributed._coalescing_manager(
@@ -1193,11 +1131,17 @@ def _all_gather_fp8_blockwise(
        async_ops=async_op,
    ) as coalescing_manager:

        # Gather row-wise data
        if quantizer.rowwise_usage:
            scale_inv_shape = list(inp._rowwise_scale_inv.size())
            scale_inv_shape[0] *= world_size
            interleaved_rowwise_scale_inv = torch.empty(
                scale_inv_shape,
                dtype=inp._rowwise_scale_inv.dtype,
                device=device,
            )
            torch.distributed.all_gather_into_tensor(
                interleaved_rowwise_scale_inv,
                inp._rowwise_scale_inv,
                group=process_group,
            )
@@ -1207,36 +1151,73 @@ def _all_gather_fp8_blockwise(
                group=process_group,
            )
        # Column-wise data
        if quantizer.columnwise_usage:
            data_shape = list(inp._columnwise_data.size())
            data_shape[0] *= world_size
            interleaved_columnwise_data = torch.empty(
                data_shape,
                dtype=inp._columnwise_data.dtype,
                device=device,
            )
            torch.distributed.all_gather_into_tensor(
                out._columnwise_scale_inv,
                inp._columnwise_scale_inv,
                group=process_group,
            )
            torch.distributed.all_gather_into_tensor(
                interleaved_columnwise_data,
                inp._columnwise_data,
                group=process_group,
            )
    # Finalize communication if needed
    async_handle = None
    if async_op:
        async_handle = _AsyncHandle(
            coalescing_manager,
            post_process_function=_finish_all_gather_fp8_blockwise,
            post_process_function_args=(
                out,
                world_size,
                interleaved_rowwise_scale_inv,
                interleaved_columnwise_data,
            ),
        )
    else:
        _finish_all_gather_fp8_blockwise(
            out,
            world_size,
            interleaved_rowwise_scale_inv,
            interleaved_columnwise_data,
        )

    return out, async_handle

def _finish_all_gather_fp8_blockwise(
    out: Float8BlockwiseQTensorStorage,
    world_size: int,
    interleaved_rowwise_scale_inv: Optional[torch.Tensor],
    interleaved_columnwise_data: Optional[torch.Tensor],
) -> Float8BlockwiseQTensorStorage:
    """Post-process FP8 blockwise gather."""

    # Fix interleaving in row-wise scales
    if interleaved_rowwise_scale_inv is not None:
        dim0 = out._rowwise_scale_inv.size(0)
        view_in = interleaved_rowwise_scale_inv.view(world_size, dim0, -1)
        view_out = out._rowwise_scale_inv.view(dim0, world_size, -1)
        tex.swap_first_dims(view_in, out=view_out)

    # Fix interleaving in column-wise data
    if interleaved_columnwise_data is not None:
        dim0 = out._columnwise_data.size(0)
        view_in = interleaved_columnwise_data.view(world_size, dim0, -1)
        view_out = out._columnwise_data.view(dim0, world_size, -1)
        tex.swap_first_dims(view_in, out=view_out)

    return out
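The `tex.swap_first_dims` calls undo the rank-major interleaving: all-gather stacks each rank's chunk along the leading dimension, while the target buffer needs the per-rank rows interleaved. A pure-Python sketch of the assumed semantics (`out[i][j] = in[j][i]`, with `swap_first_dims` here an illustrative stand-in for the kernel):

```python
def swap_first_dims(x):
    """Swap the two leading dims: turn a rank-major stack of chunks
    (world_size, dim0, ...) into an interleaved (dim0, world_size, ...)
    layout, so row i holds every rank's contribution for index i."""
    world_size, dim0 = len(x), len(x[0])
    return [[x[w][i] for w in range(world_size)] for i in range(dim0)]


# Two ranks each contributed 3 rows; all_gather stacked them rank-major.
gathered = [[f"r{w}_row{i}" for i in range(3)] for w in range(2)]
fixed = swap_first_dims(gathered)
print(fixed[0])  # ['r0_row0', 'r1_row0']
```

Doing this as a strided copy after the gather avoids issuing one small collective per rank-chunk, at the cost of one temporary buffer per gathered tensor.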

def _swap_first_dims(tensor: torch.Tensor, world_size: int):
@@ -1250,7 +1231,7 @@ def _swap_first_dims(tensor: torch.Tensor, world_size: int):
    """
    shape = tensor.shape
    assert len(shape) >= 2, "Wrong number of dimensions for fixing interleave."
    first_dim = shape[0]
    flattened_trailing = math.prod(shape[1:])
    assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave."
@@ -1681,7 +1662,7 @@ def gather_along_first_dim(
    if isinstance(inp, Float8BlockwiseQTensorStorage) or isinstance(
        quantizer, Float8BlockQuantizer
    ):
        return _start_all_gather_fp8_blockwise(
            inp,
            process_group,
            async_op=async_op,
@@ -1719,10 +1700,6 @@ def gather_along_first_dim(
        )
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize()
# Falling back to high-precision all-gather for Float8BlockQuantizer
# means that it should directly output GEMM_READY format
compact = _get_quantizer_format(quantizer)
_set_quantizer_format(quantizer, compact=False)
    out = torch.empty(
        out_shape,
        dtype=inp.dtype,
@@ -1731,7 +1708,6 @@ def gather_along_first_dim(
    )
    torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
    out = quantizer(out)
_set_quantizer_format(quantizer, compact=compact)
    return out, None
    # Dequantize quantized tensor if not supported
@@ -560,6 +560,8 @@ def fill_userbuffers_buffer_for_all_gather(
            "Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
            f"but got MXFP8 tensor with shape={tuple(local_shape)}"
        )
    if local_tensor._with_gemm_swizzled_scales:
        raise ValueError("Userbuffers assumes MXFP8 tensors have unswizzled scales")
    local_scale_inv = (
        local_tensor._rowwise_scale_inv
        if with_rowwise_data
@@ -592,6 +594,7 @@ def fill_userbuffers_buffer_for_all_gather(
        columnwise_scale_inv=columnwise_scale_inv,
        fp8_dtype=local_tensor._fp8_dtype,
        quantizer=quantizer,
        with_gemm_swizzled_scales=False,
    )
    return global_tensor, local_tensor
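The two guards above form a precondition check: Userbuffers only accepts MXFP8 shards whose dims are multiples of 128 and whose scales are still in the plain, unswizzled layout. A hedged sketch of that check in isolation (the `FakeMXFP8` stand-in and function name are assumptions; the attribute names mirror the ones used above):

```python
class FakeMXFP8:
    """Minimal stand-in for an MXFP8 tensor, for illustration only."""

    def __init__(self, shape, with_gemm_swizzled_scales):
        self.shape = shape
        self._with_gemm_swizzled_scales = with_gemm_swizzled_scales


def check_userbuffers_preconditions(local_tensor):
    # Userbuffers copies raw shards, so every dim must tile evenly into
    # the communication buffer (multiples of 128 here).
    if any(dim % 128 != 0 for dim in local_tensor.shape):
        raise ValueError(
            "Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
            f"but got MXFP8 tensor with shape={tuple(local_tensor.shape)}"
        )
    # Scales must not already be in the GEMM-swizzled layout, since the
    # gather happens before the GEMM rearranges them.
    if local_tensor._with_gemm_swizzled_scales:
        raise ValueError("Userbuffers assumes MXFP8 tensors have unswizzled scales")


check_userbuffers_preconditions(FakeMXFP8((256, 128), False))  # passes silently
```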
@@ -720,13 +720,9 @@ class GroupedLinear(TransformerEngineBaseModule):
        """Init scales and amaxes for fwd | bwd."""
        super().set_meta_tensor(fwd, recipe)
        # Recipe-specific quantizer configuration
        recipe = FP8GlobalStateManager.get_fp8_recipe()
        if recipe.float8_current_scaling():
assert not self.tp_size > 1, (
"GroupedLinear doesn't support TP > 1 with Float8 current scaling. "
"Because the TP communication is handled outside of this module."
)
            self._customize_quantizers_float8_current_scaling(fwd, recipe)
    def reset_parameters(self, defer_init=False):
@@ -879,9 +875,12 @@ class GroupedLinear(TransformerEngineBaseModule):
    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
        """Customize quantizers based on current scaling recipe + linear."""
        assert not self.tp_size > 1, (
            "GroupedLinear doesn't support TP > 1 with Float8 current scaling, "
            "because the TP communication is handled outside of this module."
        )
        if fwd:
            for i in range(self.num_gemms):
                # set configs about amax epsilon and power_2_scale
@@ -954,9 +953,9 @@ class GroupedLinear(TransformerEngineBaseModule):
            ]
            for i in range(self.num_gemms)
        ]
        # pylint: disable=fixme
        for i in range(self.num_gemms):
            input_quantizers[i].internal = True
            input_quantizers[i].optimize_for_gemm = True
        if torch.is_grad_enabled():
            grad_output_quantizers = [
                self.quantizers["scaling_bwd"][
@@ -966,6 +965,7 @@ class GroupedLinear(TransformerEngineBaseModule):
            ]
            for i in range(self.num_gemms):
                grad_output_quantizers[i].internal = True
                grad_output_quantizers[i].optimize_for_gemm = True
        return (
            input_quantizers,
            weight_quantizers,
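The per-GEMM loops above apply the same two flags to every quantizer in a list. A minimal sketch of that pattern with a stand-in quantizer object (`StubQuantizer` and its default-False flags are assumptions for illustration; per this commit, `optimize_for_gemm` appears to request the GEMM-swizzled scale layout for tensors only consumed by GEMMs):

```python
class StubQuantizer:
    """Stand-in for a TE quantizer; only the two flags touched above."""

    def __init__(self):
        self.internal = False
        self.optimize_for_gemm = False


num_gemms = 4
input_quantizers = [StubQuantizer() for _ in range(num_gemms)]
grad_output_quantizers = [StubQuantizer() for _ in range(num_gemms)]

# Same pattern as the loops above: flag every per-GEMM quantizer as
# internal-only and eligible for the GEMM-optimized scale layout.
for i in range(num_gemms):
    input_quantizers[i].internal = True
    input_quantizers[i].optimize_for_gemm = True
    grad_output_quantizers[i].internal = True
    grad_output_quantizers[i].optimize_for_gemm = True
```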
@@ -64,7 +64,6 @@ from ..quantized_tensor import (
    restore_from_saved,
)
from ...debug.pytorch.debug_state import TEDebugState
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..cpu_offload import (
    is_cpu_offload_enabled,
@@ -253,8 +252,6 @@ class _LayerNormLinear(torch.autograd.Function):
        if fp8 or debug:
            ln_out = input_quantizer(ln_out)
            input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(input_quantizer, Float8BlockQuantizer):
input_quantizer.all_gather_usage = False
            ln_out_total = input_quantizer(ln_out_total)
        else:
            quantizer = None
@@ -1409,15 +1406,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
        """Init scales and amaxes for fwd | bwd."""
        super().set_meta_tensor(fwd, recipe)
        # Recipe-specific quantizer configuration
        recipe = FP8GlobalStateManager.get_fp8_recipe()
        if recipe.float8_current_scaling():
            self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
        elif recipe.nvfp4():
            self._customize_quantizers_nvfp4(fwd, recipe)
# elif other recipes (mxfp8, etc)
    def reset_layer_norm_parameters(self) -> None:
        """Init LN params"""
@@ -1619,12 +1613,16 @@ class LayerNormLinear(TransformerEngineBaseModule):
        output_quantizer = None
        input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        input_quantizer.internal = True
        if not (self.parallel_mode == "column" and self.sequence_parallel):
            input_quantizer.optimize_for_gemm = True
        (weight_quantizer,) = self._get_weight_quantizers()
        if fp8_output:
            output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        if is_grad_enabled:
            grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_output_quantizer.internal = True
            if not (self.parallel_mode == "row" and self.sequence_parallel):
                grad_output_quantizer.optimize_for_gemm = True
        if fp8_grad:
            grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
@@ -1808,14 +1806,3 @@ class LayerNormLinear(TransformerEngineBaseModule):
        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        weight_quantizer.internal = True
        return [weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_linear."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
@@ -431,8 +431,6 @@ class _LayerNormMLP(torch.autograd.Function):
        if fp8 or debug:
            ln_out = fc1_input_quantizer(ln_out)
            fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
fc1_input_quantizer.all_gather_usage = False
            ln_out_total = fc1_input_quantizer(ln_out_total)
        else:
            quantizer = None
@@ -1964,15 +1962,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
        """Init scales and amaxes for fwd | bwd."""
        super().set_meta_tensor(fwd, recipe)
        # Recipe-specific quantizer configuration
        recipe = FP8GlobalStateManager.get_fp8_recipe()
        if recipe.float8_current_scaling():
            self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
        elif recipe.nvfp4():
            self._customize_quantizers_nvfp4(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
    def reset_layer_norm_parameters(self) -> None:
        """Init LN params"""
@@ -2193,6 +2188,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
        if self.fp8 or self.fp8_calibration:
            fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
            fc1_input_quantizer.internal = True
            if not self.sequence_parallel:
                fc1_input_quantizer.optimize_for_gemm = True
            fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
            fc2_input_quantizer.set_usage(
                rowwise=True,
@@ -2201,7 +2198,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
                    (MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer),
                ),
            )
            fc2_input_quantizer.internal = True
            fc2_input_quantizer.optimize_for_gemm = True
            if fp8_output:
                fc2_output_quantizer = self.quantizers["scaling_fwd"][
                    tex.FP8FwdTensors.GEMM2_OUTPUT
@@ -2211,10 +2209,13 @@ class LayerNormMLP(TransformerEngineBaseModule):
                    tex.FP8BwdTensors.GRAD_OUTPUT2
                ]
                fc2_grad_output_quantizer.internal = True
                if not self.sequence_parallel:
                    fc2_grad_output_quantizer.optimize_for_gemm = True
                fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
                    tex.FP8BwdTensors.GRAD_OUTPUT1
                ]
                fc1_grad_output_quantizer.internal = True
                fc1_grad_output_quantizer.optimize_for_gemm = True
        return (
            fc1_input_quantizer,
@@ -2467,22 +2468,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
            fc2_weight_quantizer.internal = True
        return [fc1_weight_quantizer, fc2_weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_mlp."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
else:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
].all_gather_usage = True
    def backward_dw(self):
        """
        Execute the delayed weight gradient computation.
@@ -1313,15 +1313,12 @@ class Linear(TransformerEngineBaseModule):
        """Init scales and amaxes for fwd | bwd."""
        super().set_meta_tensor(fwd, recipe)
        # Recipe-specific quantizer configuration
        recipe = FP8GlobalStateManager.get_fp8_recipe()
        if recipe.float8_current_scaling():
            self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
        elif recipe.nvfp4():
            self._customize_quantizers_nvfp4(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
    def reset_parameters(self, defer_init=False):
        super().reset_parameters(defer_init=defer_init)
@@ -1489,12 +1486,16 @@ class Linear(TransformerEngineBaseModule):
        output_quantizer = None
        input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        input_quantizer.internal = True
        if not (self.parallel_mode == "column" and self.sequence_parallel):
            input_quantizer.optimize_for_gemm = True
        (weight_quantizer,) = self._get_weight_quantizers()
        if fp8_output:
            output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        if is_grad_enabled:
            grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_output_quantizer.internal = True
            if not (self.parallel_mode == "row" and self.sequence_parallel):
                grad_output_quantizer.optimize_for_gemm = True
        if fp8_grad:
            grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
        return (
@@ -1669,22 +1670,3 @@ class Linear(TransformerEngineBaseModule):
        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
        weight_quantizer.internal = True
        return [weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + linear."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# set compact for inp tensor X
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
else:
if self.sequence_parallel and self.parallel_mode == "row":
# set compact for grad_output tensor dY
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].all_gather_usage = True
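Across Linear, LayerNormLinear, and BasicLinear, `optimize_for_gemm` is gated on the same condition: a tensor that will be all-gathered over the sequence-parallel group (the input under column parallelism, the grad output under row parallelism) must keep the plain scale layout, since the gather happens before the GEMM. A hedged sketch of that shared predicate (the function name is illustrative):

```python
def use_gemm_optimized_layout(role, parallel_mode, sequence_parallel):
    """Decide whether a quantizer may use the GEMM-optimized scale layout.

    Mirrors the gating applied in the modules above: sequence-parallel
    all-gathers require unswizzled scales, so the optimization is only
    enabled when no tensor-parallel gather precedes the GEMM.
    """
    if role == "input":
        # Input is gathered under column parallelism + sequence parallelism
        return not (parallel_mode == "column" and sequence_parallel)
    if role == "grad_output":
        # Grad output is gathered under row parallelism + sequence parallelism
        return not (parallel_mode == "row" and sequence_parallel)
    raise ValueError(f"unknown role: {role}")
```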
@@ -342,15 +342,21 @@ class BasicLinear(BasicOperation):
    def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
        super().reset_recipe_state(recipe=recipe)
        # Configure input/grad output quantizers
        # Note: These tensors are only used internally. If there is no
        # tensor-parallel communication, they are only used for GEMM.
        input_quantizer = self.get_quantizer("forward", 0)
        grad_output_quantizer = self.get_quantizer("backward", 0)
        if input_quantizer is not None:
            input_quantizer.internal = True
            if not (self.tensor_parallel_mode == "column" and self.sequence_parallel):
                input_quantizer.optimize_for_gemm = True
        if grad_output_quantizer is not None:
            grad_output_quantizer.internal = True
            if not (self.tensor_parallel_mode == "row" and self.sequence_parallel):
                grad_output_quantizer.optimize_for_gemm = True
        # Configure weight quantizer
        # Note: This function may be called in base class constructor,
        # before any basic linear attrs have been set.
        weight_quantizer = self.get_quantizer("forward", 1)
@@ -292,6 +292,7 @@ class UserbuffersBackwardLinear(FusedOperation):
                rowwise=True,
                columnwise=with_columnwise,
            )
            grad_output_quantizer.optimize_for_gemm = False
            dy_local = grad_output_quantizer(dy_local)
        else:
            dy_local = maybe_dequantize(dy_local, dtype)
@@ -294,6 +294,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
            columnwise_scale_inv=None,
            quantizer=None,
            requires_grad=output.requires_grad,
            with_gemm_swizzled_scales=False,
        )
        ctx.save_for_backward(row_id_map, pad_offsets)
@@ -504,6 +505,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
            columnwise_scale_inv=None,
            quantizer=None,
            requires_grad=act_grad.requires_grad,
            with_gemm_swizzled_scales=False,
        )
        if not ctx.needs_input_grad[2]: