Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
@@ -20,4 +20,14 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(
       force_pow_2_scales, epsilon, at::cuda::getCurrentCUDAStream());
 }
 
+void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, const py::object &dummy,
+                                              std::vector<std::vector<at::Tensor>> tensor_lists) {
+  NVTE_CHECK(dummy.is_none(), "No-op flag is not supported.");
+  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
+      makeTransformerEngineTensorList(tensor_lists);
+  nvte_multi_tensor_compute_scale_inv_e8m0_cuda(chunk_size, tensor_lists_ptr.data(), num_lists,
+                                                num_tensors, at::cuda::getCurrentCUDAStream());
+}
+
 }  // namespace transformer_engine::pytorch
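An E8M0 scale_inv is just a biased power-of-two exponent derived from an amax, which is why a single fused kernel can sweep whole tensor lists. Below is a standalone sketch of the per-tensor math following the OCP MX convention; it is illustrative only, and the clamping and fallback behavior are assumptions, not this kernel's code:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Encode 2^(floor(log2(amax)) - floor(log2(fp8_max))) as a biased E8M0 exponent.
uint8_t e8m0_scale_inv(float amax, float fp8_max = 448.0f /* E4M3 max */) {
  constexpr int kE8M0Bias = 127;  // E8M0 stores a biased exponent, no mantissa
  if (!(amax > 0.0f) || std::isinf(amax)) {
    return kE8M0Bias;  // degenerate amax: fall back to a scale of 2^0
  }
  const int exp = static_cast<int>(std::floor(std::log2(amax))) -
                  static_cast<int>(std::floor(std::log2(fp8_max)));
  return static_cast<uint8_t>(std::clamp(exp + kE8M0Bias, 0, 255));
}

int main() {
  std::printf("%d\n", e8m0_scale_inv(448.0f));  // prints 127, i.e. a scale_inv of 2^0
}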
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
...
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
...
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
...
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
@@ -64,6 +64,11 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
                                       const bool zero_centered_gamma) {
   using namespace transformer_engine::pytorch::detail;
 
+  // Ensure that the cuDNN handle is created on the correct device,
+  // overriding torch.cuda.set_device calls from the user side.
+  // Assumes all tensors passed in are on the same device.
+  at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device());
+
   // Input and param tensors
   auto none = py::none();
   const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
@@ -84,14 +89,8 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   TensorWrapper mu_nvte = makeTransformerEngineTensor(mu_py);
   TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);
 
-  // Output tensor
+  // Quantizer
   auto quantizer_cpp = convert_quantizer(quantizer);
-  TensorWrapper out_nvte;
-  if (out.is_none()) {
-    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
-  } else {
-    out_nvte = makeTransformerEngineTensor(out, quantizer);
-  }
 
   // Choose implementation
   enum class Impl {
@@ -130,6 +129,19 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
     }
   }
 
+  // Output tensor
+  TensorWrapper out_nvte;
+  if (out.is_none()) {
+    if (impl == Impl::FULLY_FUSED) {
+      // FP8 has no special logic to optimize for GEMM, and the MXFP8 cuDNN
+      // kernel does not support GEMM-swizzled scales
+      quantizer_cpp->optimize_for_gemm = false;
+    }
+    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
+  } else {
+    out_nvte = makeTransformerEngineTensor(out, quantizer);
+  }
+
   // Construct unquantized output tensor if needed
   TensorWrapper unquantized_out_nvte;
   py::object unquantized_out;
@@ -294,6 +306,11 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
                                     const int sm_margin, const bool zero_centered_gamma) {
   using namespace transformer_engine::pytorch::detail;
 
+  // Ensure that the cuDNN handle is created on the correct device,
+  // overriding torch.cuda.set_device calls from the user side.
+  // Assumes all tensors passed in are on the same device.
+  at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device());
+
   // Input and param tensors
   auto none = py::none();
   const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
@@ -308,14 +325,8 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   at::Tensor rsigma_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
   TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);
 
-  // Output tensor
+  // Quantizer
   auto quantizer_cpp = convert_quantizer(quantizer);
-  TensorWrapper out_nvte;
-  if (out.is_none()) {
-    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
-  } else {
-    out_nvte = makeTransformerEngineTensor(out, quantizer);
-  }
 
   // Choose implementation
   enum class Impl {
@@ -354,6 +365,19 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
     }
   }
 
+  // Output tensor
+  TensorWrapper out_nvte;
+  if (out.is_none()) {
+    if (impl == Impl::FULLY_FUSED) {
+      // FP8 has no special logic to optimize for GEMM, and the MXFP8 cuDNN
+      // kernel does not support GEMM-swizzled scales
+      quantizer_cpp->optimize_for_gemm = false;
+    }
+    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
+  } else {
+    out_nvte = makeTransformerEngineTensor(out, quantizer);
+  }
+
   // Construct unquantized output tensor if needed
   TensorWrapper unquantized_out_nvte;
   py::object unquantized_out;
...
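The new guard protects against callers that switch the current CUDA device between allocating tensors and invoking the op. A minimal sketch of the RAII pattern relied on here, using the c10 spelling; illustrative only, not this file's code:

#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>

void run_on_tensors_device(const at::Tensor& x) {
  // Saves the current CUDA device, switches to x's device, and restores
  // the previous device on scope exit, so lazily created library handles
  // (e.g. cuDNN) end up on the device where the data actually lives.
  c10::cuda::CUDAGuard device_guard(x.device());
  // ... launch kernels / create handles here ...
}  // previous current device restored here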
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
...
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
...
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
...
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
@@ -248,7 +248,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
   m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
         "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"),
-        py::arg("quantizer_list"));
+        py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false);
   m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
         "Grouped GEMM");
 #ifdef USE_ROCM
@@ -296,10 +296,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"),
         py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
         py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
+  m.def("mxfp8_scaling_compute_partial_amax",
+        &transformer_engine::pytorch::mxfp8_scaling_compute_partial_amax,
+        "Compute partial amax from master weights for mxfp8 scaling", py::arg("input"),
+        py::arg("amax_rowwise"), py::arg("amax_colwise"), py::arg("rows"), py::arg("cols"),
+        py::arg("start_offset"), py::call_guard<py::gil_scoped_release>());
+  m.def("mxfp8_scaling_partial_cast", &transformer_engine::pytorch::mxfp8_scaling_partial_cast,
+        "Partial cast from master weights for mxfp8 scaling", py::arg("input"),
+        py::arg("output_rowwise"), py::arg("output_colwise"), py::arg("scale_inv_rowwise"),
+        py::arg("scale_inv_colwise"), py::arg("rows"), py::arg("cols"), py::arg("start_offset"),
+        py::call_guard<py::gil_scoped_release>());
   m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding,
         "Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());
   m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding,
         "Fused Multi-tensor unpadding", py::call_guard<py::gil_scoped_release>());
+  m.def("swizzle_scales_for_gemm_", &transformer_engine::pytorch::inplace_swizzle_scale_for_gemm,
+        "Convert tensor block scales into GEMM swizzled format");
 
   // attention kernels
   m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
@@ -450,6 +462,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_compute_scale_and_scale_inv",
         &transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda,
         "Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());
+  m.def("multi_tensor_compute_scale_inv_e8m0",
+        &transformer_engine::pytorch::multi_tensor_compute_scale_inv_e8m0_cuda,
+        "Fused compute E8M0 scale_inv from amax", py::call_guard<py::gil_scoped_release>());
 
   // Comm+GEMM Overlap
   m.def("bulk_overlap_ag_with_external_gemm",
...
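Since disable_bulk_allocation is bound with a default, existing Python callers of split_quantize are unaffected. A self-contained pybind11 sketch of the same defaulted-keyword pattern, with illustrative names:

#include <pybind11/pybind11.h>
namespace py = pybind11;

int split_work(int n, bool disable_bulk_allocation) {
  return disable_bulk_allocation ? n : 2 * n;  // stand-in for the real logic
}

PYBIND11_MODULE(example, m) {
  // Callers may omit the flag: example.split_work(3) == example.split_work(3, False)
  m.def("split_work", &split_work, py::arg("n"), py::arg("disable_bulk_allocation") = false);
}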
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
@@ -22,8 +22,8 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
   TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
   auto* amax_ptr = amax.data_ptr<float>();
   TensorWrapper fake_te_output(
-      nullptr, te_input.shape(),
-      DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
+      /*dptr=*/nullptr, te_input.shape(),
+      DType::kFloat32,  // It doesn't matter because we only compute amax.
       amax_ptr);
   nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
...
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
...
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
...
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
#include "common.h"
#include "common/common.h"
#include "extensions.h"
#include "pybind.h"
#include "util.h"
namespace transformer_engine {
namespace pytorch {
namespace {
void reset_tensor_data(transformer_engine::TensorWrapper &tensor, bool rowwise, bool columnwise) {
NVTEShape shape;
shape.ndim = 1;
shape.data[0] = 0;
const transformer_engine::DType dtype = transformer_engine::DType::kFloat32;
if (rowwise) {
tensor.set_rowwise_data(nullptr, dtype, shape);
tensor.set_rowwise_scale_inv(nullptr, dtype, shape);
}
if (columnwise) {
tensor.set_columnwise_data(nullptr, dtype, shape);
tensor.set_columnwise_scale_inv(nullptr, dtype, shape);
}
}
} // namespace
std::tuple<std::optional<at::Tensor>, std::optional<at::Tensor>> swizzle_scales_for_gemm(
transformer_engine::TensorWrapper &tensor, bool rowwise_usage, bool columnwise_usage) {
// Return early if scale swizzling is not required
const auto scaling_mode = tensor.scaling_mode();
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING:
// Tensor format requires scale swizzling
break;
case NVTE_INVALID_SCALING:
NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
default:
// Tensor format does not require scale swizzling for GEMM
return {std::nullopt, std::nullopt};
}
// Return early if scales are already swizzled
if (tensor.get_with_gemm_swizzled_scales()) {
return {std::nullopt, std::nullopt};
}
// CUDA stream
auto stream = at::cuda::getCurrentCUDAStream();
// Swizzle row-wise scales if needed
std::optional<at::Tensor> rowwise_scales_pyt;
if (rowwise_usage) {
// Buffer for unswizzled scales
const auto input_scales_nvte = tensor.get_rowwise_scale_inv();
void *input_scales_dptr = input_scales_nvte.data_ptr;
const NVTEShape input_scales_shape = input_scales_nvte.shape;
const auto scales_dtype = static_cast<DType>(input_scales_nvte.dtype);
// Allocate buffer for swizzled scales
const NVTEShape output_scales_shape = input_scales_shape;
rowwise_scales_pyt = allocateSpace(input_scales_shape, scales_dtype, false);
void *output_scales_dptr = getDataPtr(*rowwise_scales_pyt);
// Initialize TE tensors with scales
const auto data_nvte = tensor.get_rowwise_data();
const auto data_dtype = static_cast<DType>(data_nvte.dtype);
TensorWrapper input_nvte(scaling_mode);
input_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape);
input_nvte.set_rowwise_scale_inv(input_scales_dptr, scales_dtype, input_scales_shape);
TensorWrapper output_nvte(scaling_mode);
output_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape);
output_nvte.set_rowwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape);
output_nvte.set_with_gemm_swizzled_scales(true);
// Launch kernel
NVTE_SCOPED_GIL_RELEASE(
{ nvte_swizzle_scaling_factors(input_nvte.data(), output_nvte.data(), stream); });
// Update tensor with swizzled scales
tensor.set_rowwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape);
}
// Swizzle column-wise scales if needed
std::optional<at::Tensor> columnwise_scales_pyt;
if (columnwise_usage) {
// Buffer for unswizzled scales
const auto input_scales_nvte = tensor.get_columnwise_scale_inv();
void *input_scales_dptr = input_scales_nvte.data_ptr;
const NVTEShape input_scales_shape = input_scales_nvte.shape;
const auto scales_dtype = static_cast<DType>(input_scales_nvte.dtype);
// Allocate buffer for swizzled scales
const NVTEShape output_scales_shape = input_scales_shape;
columnwise_scales_pyt = allocateSpace(input_scales_shape, scales_dtype, false);
void *output_scales_dptr = getDataPtr(*columnwise_scales_pyt);
// Initialize TE tensors with scales
const auto data_nvte = tensor.get_columnwise_data();
const auto data_dtype = static_cast<DType>(data_nvte.dtype);
TensorWrapper input_nvte(scaling_mode);
input_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape);
input_nvte.set_columnwise_scale_inv(input_scales_dptr, scales_dtype, input_scales_shape);
TensorWrapper output_nvte(scaling_mode);
output_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape);
output_nvte.set_columnwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape);
output_nvte.set_with_gemm_swizzled_scales(true);
// Launch kernel
NVTE_SCOPED_GIL_RELEASE(
{ nvte_swizzle_scaling_factors(input_nvte.data(), output_nvte.data(), stream); });
// Update tensor with swizzled scales
tensor.set_columnwise_scale_inv(output_scales_dptr, scales_dtype, output_scales_shape);
}
// Update tensor
reset_tensor_data(tensor, !rowwise_usage, !columnwise_usage);
tensor.set_with_gemm_swizzled_scales(true);
return {std::move(rowwise_scales_pyt), std::move(columnwise_scales_pyt)};
}
std::optional<at::Tensor> multi_tensor_swizzle_scales_for_gemm(
std::vector<transformer_engine::TensorWrapper> &tensors, bool rowwise_usage,
bool columnwise_usage) {
// Checks and trivial cases
NVTE_CHECK(rowwise_usage != columnwise_usage,
"Expect exactly one of rowwise_usage=", rowwise_usage,
" and columnwise_usage=", columnwise_usage, ".");
if (tensors.empty()) {
return std::nullopt;
}
const auto scaling_mode = tensors.front().scaling_mode();
for (const auto &tensor : tensors) {
NVTE_CHECK(tensor.scaling_mode() == scaling_mode, "Tensors have different scaling modes");
}
// Return early if scale swizzling is not required
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING:
// Tensor format requires scale swizzling
break;
case NVTE_INVALID_SCALING:
NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
default:
// Tensor format does not require scale swizzling for GEMM
return std::nullopt;
}
// Filter out tensors that already have swizzled scales
std::vector<TensorWrapper *> tensors_needing_swizzle;
for (auto &tensor : tensors) {
if (!tensor.get_with_gemm_swizzled_scales()) {
tensors_needing_swizzle.push_back(&tensor);
}
}
if (tensors_needing_swizzle.empty()) {
return std::nullopt;
}
// Determine buffer size needed for swizzled scales
std::vector<size_t> output_scales_offsets;
size_t output_scales_bytes = 0;
for (auto &tensor : tensors_needing_swizzle) {
const auto scales_nvte =
(rowwise_usage ? tensor->get_rowwise_scale_inv() : tensor->get_columnwise_scale_inv());
const auto &shape = scales_nvte.shape;
const auto dtype = static_cast<DType>(scales_nvte.dtype);
const auto dtype_bits = transformer_engine::pytorch::typeToNumBits(dtype);
const auto size = product(shape, 0, shape.ndim);
output_scales_bytes = roundup(output_scales_bytes, 16); // align to 16B
output_scales_offsets.push_back(output_scales_bytes);
output_scales_bytes += ceildiv(size * dtype_bits, 8);
}
// Allocate buffer for swizzled scales
auto output_scales_pyt = allocateSpace(std::vector<size_t>{output_scales_bytes},
transformer_engine::DType::kByte, false);
uint8_t *output_scales_dptr = reinterpret_cast<uint8_t *>(getDataPtr(output_scales_pyt));
// Construct TE tensors with only scales
std::vector<transformer_engine::TensorWrapper> inputs_nvte, outputs_nvte;
for (size_t i = 0; i < tensors_needing_swizzle.size(); ++i) {
auto &tensor = *tensors_needing_swizzle[i];
inputs_nvte.emplace_back(scaling_mode);
outputs_nvte.emplace_back(scaling_mode);
auto &input_nvte = inputs_nvte.back();
auto &output_nvte = outputs_nvte.back();
output_nvte.set_with_gemm_swizzled_scales(true);
if (rowwise_usage) {
const auto data_nvte = tensor.get_rowwise_data();
const auto scales_nvte = tensor.get_rowwise_scale_inv();
const auto data_dtype = static_cast<transformer_engine::DType>(data_nvte.dtype);
const auto scales_dtype = static_cast<transformer_engine::DType>(scales_nvte.dtype);
input_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape);
input_nvte.set_rowwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape);
output_nvte.set_rowwise_data(nullptr, data_dtype, data_nvte.shape);
output_nvte.set_rowwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype,
scales_nvte.shape);
} else {
const auto data_nvte = tensor.get_columnwise_data();
const auto scales_nvte = tensor.get_columnwise_scale_inv();
const auto data_dtype = static_cast<transformer_engine::DType>(data_nvte.dtype);
const auto scales_dtype = static_cast<transformer_engine::DType>(scales_nvte.dtype);
input_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape);
input_nvte.set_columnwise_scale_inv(scales_nvte.data_ptr, scales_dtype, scales_nvte.shape);
output_nvte.set_columnwise_data(nullptr, data_dtype, data_nvte.shape);
output_nvte.set_columnwise_scale_inv(output_scales_dptr + output_scales_offsets[i],
scales_dtype, scales_nvte.shape);
}
}
// Pack raw NVTETensors into vectors
std::vector<NVTETensor> inputs_nvte_raw, outputs_nvte_raw;
for (auto &tensor : inputs_nvte) {
inputs_nvte_raw.emplace_back(tensor.data());
}
for (auto &tensor : outputs_nvte) {
outputs_nvte_raw.emplace_back(tensor.data());
}
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_tensor_swizzle_scaling_factors(inputs_nvte_raw.data(), outputs_nvte_raw.data(),
inputs_nvte_raw.size(),
at::cuda::getCurrentCUDAStream());
});
// Update tensors with swizzled scales
for (size_t i = 0; i < tensors_needing_swizzle.size(); ++i) {
auto &tensor = *tensors_needing_swizzle[i];
reset_tensor_data(tensor, !rowwise_usage, !columnwise_usage);
tensor.set_with_gemm_swizzled_scales(true);
if (rowwise_usage) {
auto scales_nvte = outputs_nvte[i].get_rowwise_scale_inv();
const auto scales_dtype = static_cast<transformer_engine::DType>(scales_nvte.dtype);
tensor.set_rowwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype,
scales_nvte.shape);
} else {
auto scales_nvte = outputs_nvte[i].get_columnwise_scale_inv();
const auto scales_dtype = static_cast<transformer_engine::DType>(scales_nvte.dtype);
tensor.set_columnwise_scale_inv(output_scales_dptr + output_scales_offsets[i], scales_dtype,
scales_nvte.shape);
}
}
return std::move(output_scales_pyt);
}
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input,
bool rowwise) {
// Check input tensor
const NVTEScalingMode scaling_mode = input.scaling_mode();
NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
"Input tensor must be a block scaling tensor");
// Get tensor data
NVTEBasicTensor data;
size_t data_flat_first_dim = 1;
size_t data_flat_last_dim = 1;
if (rowwise) {
data = input.get_rowwise_data();
for (size_t i = 0; i < data.shape.ndim - 1; ++i) {
data_flat_first_dim *= data.shape.data[i];
}
data_flat_last_dim = data.shape.data[data.shape.ndim - 1];
} else {
data = input.get_columnwise_data();
data_flat_first_dim = data.shape.data[0];
for (size_t i = 1; i < data.shape.ndim; ++i) {
data_flat_last_dim *= data.shape.data[i];
}
}
NVTEShape data_shape{};
data_shape.data[0] = data_flat_first_dim;
data_shape.data[1] = data_flat_last_dim;
data_shape.ndim = 2;
// Recreate input tensor with rowwise usage
transformer_engine::TensorWrapper input_cu(scaling_mode);
input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
const NVTEBasicTensor scale_inv =
rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv();
input_cu.set_rowwise_scale_inv(
scale_inv.data_ptr, static_cast<transformer_engine::DType>(scale_inv.dtype), scale_inv.shape);
// Create output tensor
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
// Output swizzled mxfp8 scaling factor dimensions
const size_t swizzled_scale_inv_first_dim = ceildiv(data_flat_first_dim, 128) * 128;
const size_t swizzled_scale_inv_last_dim = ceildiv(data_flat_last_dim, 128) * 4;
// Allocate memory for swizzled mxfp8 scaling factors
at::Tensor swizzled_scale_inv =
allocateSpace(std::vector<size_t>{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim},
transformer_engine::DType::kByte, false);
// Set rowwise scaling factors on output
void *const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
NVTEShape swizzled_scale_inv_shape{};
swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim;
swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim;
swizzled_scale_inv_shape.ndim = 2;
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
swizzled_scale_inv_shape);
output_cu.set_with_gemm_swizzled_scales(true);
// Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format
NVTE_SCOPED_GIL_RELEASE({
nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
});
// Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling factor
// for it to be kept alive during the GEMM
input = std::move(output_cu);
return swizzled_scale_inv;
}
void inplace_swizzle_scale_for_gemm(py::handle &tensor) {
// Convert Python tensor to C++ tensor
auto tensor_nvte = makeTransformerEngineTensor(tensor, py::none());
// Return early if scale swizzling is not required
const auto scaling_mode = tensor_nvte.scaling_mode();
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING:
// Tensor format requires scale swizzling
break;
case NVTE_INVALID_SCALING:
NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
default:
// Tensor format does not require scale swizzling for GEMM
return;
}
// Return early if scales are already swizzled
if (tensor_nvte.get_with_gemm_swizzled_scales()) {
return;
}
// Check what scaling factors the tensor contains
auto is_empty = [](const NVTEBasicTensor &t) -> bool {
return t.shape.ndim == 1 && t.shape.data[0] == 0;
};
const bool has_rowwise_scales = !is_empty(tensor_nvte.get_rowwise_scale_inv());
const bool has_columnwise_scales = !is_empty(tensor_nvte.get_columnwise_scale_inv());
// Swizzle scaling factors
auto [rowwise_scales, columnwise_scales] =
swizzle_scales_for_gemm(tensor_nvte, has_rowwise_scales, has_columnwise_scales);
// Update Python tensor with swizzled scales
switch (scaling_mode) {
  case NVTE_MXFP8_1D_SCALING:
  case NVTE_NVFP4_1D_SCALING:
    if (has_rowwise_scales) {
      tensor.attr("_rowwise_scale_inv") = rowwise_scales;
    }
    if (has_columnwise_scales) {
      tensor.attr("_columnwise_scale_inv") = columnwise_scales;
    }
    tensor.attr("_with_gemm_swizzled_scales") = true;
    break;
  default:
    NVTE_ERROR("Invalid scaling mode for swizzling scaling factors.");
}
}
} // namespace pytorch
} // namespace transformer_engine
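Two pieces of arithmetic in this new file are easy to miss: multi_tensor_swizzle_scales_for_gemm packs every tensor's swizzled scales into a single byte buffer at 16 B-aligned offsets, and convert_block_scaling_to_mxfp8_tensor pads the swizzled scale buffer to 128-row and 4-column groups. A standalone restatement of both, using the same formulas as the code above with illustrative wrapper names:

#include <cstdio>
#include <utility>
#include <vector>

constexpr size_t ceildiv(size_t a, size_t b) { return (a + b - 1) / b; }
constexpr size_t roundup(size_t a, size_t b) { return ceildiv(a, b) * b; }

// Pack buffer sizes into one allocation with 16 B-aligned offsets,
// mirroring the offset loop in multi_tensor_swizzle_scales_for_gemm.
std::pair<std::vector<size_t>, size_t> pack_offsets(const std::vector<size_t>& sizes) {
  std::vector<size_t> offsets;
  size_t total = 0;
  for (size_t n : sizes) {
    total = roundup(total, 16);
    offsets.push_back(total);
    total += n;
  }
  return {offsets, total};
}

// Padded swizzled-scale dims used when converting block scaling to MXFP8:
// rows padded to multiples of 128, one scale per 128 columns in groups of 4.
std::pair<size_t, size_t> swizzled_scale_dims(size_t rows, size_t cols) {
  return {ceildiv(rows, 128) * 128, ceildiv(cols, 128) * 4};
}

int main() {
  auto [offsets, total] = pack_offsets({100, 37, 64});
  std::printf("%zu %zu %zu | total %zu\n", offsets[0], offsets[1], offsets[2], total);
  // 0 112 160 | total 224
  auto [r, c] = swizzled_scale_dims(4096, 8192);
  std::printf("%zu x %zu\n", r, c);  // 4096 x 256
}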
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
...
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
...
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
@@ -52,10 +52,12 @@ Quantizer::Quantizer(const py::handle& quantizer) {
     this->rowwise_usage = true;
     this->columnwise_usage = true;
     this->internal = false;
+    this->optimize_for_gemm = false;
   } else {
     this->rowwise_usage = quantizer.attr("rowwise_usage").cast<bool>();
     this->columnwise_usage = quantizer.attr("columnwise_usage").cast<bool>();
     this->internal = quantizer.attr("internal").cast<bool>();
+    this->optimize_for_gemm = quantizer.attr("optimize_for_gemm").cast<bool>();
     this->quantizer = quantizer;
   }
 }
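The constructor change means any non-internal Python quantizer must now expose an optimize_for_gemm attribute alongside the existing usage flags. A minimal sketch of the attribute surface being read; the struct and function names are illustrative, while the attribute names come from the diff:

#include <pybind11/pybind11.h>
namespace py = pybind11;

struct QuantizerConfig {
  bool rowwise_usage, columnwise_usage, internal, optimize_for_gemm;
};

// Mirrors the attribute reads in the Quantizer constructor above.
QuantizerConfig read_config(const py::handle& q) {
  return {q.attr("rowwise_usage").cast<bool>(), q.attr("columnwise_usage").cast<bool>(),
          q.attr("internal").cast<bool>(),
          q.attr("optimize_for_gemm").cast<bool>()};  // new in this commit
}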
@@ -555,7 +557,6 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
   this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
   NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
              "Unsupported block scaling dim.");
-  this->all_gather_usage = quantizer.attr("all_gather_usage").cast<bool>();
 }
 
 void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {}
@@ -575,10 +576,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
   opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
   scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
 
-  Float8BlockScaleTensorFormat data_format =
-      (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
-                        : Float8BlockScaleTensorFormat::GEMM_READY);
-
   if (rowwise_usage) {
     data_rowwise = at::empty(torch_shape, opts);
     auto scale_shape = get_scale_shape(shape, false);
@@ -597,7 +594,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ",
                columnwise_shape, " torch shape: ", torch_columnwise_shape);
     if (torch_shape.size() > 0) {
-      if (!all_gather_usage) {
       torch_columnwise_shape.reserve(torch_shape.size());
       columnwise_shape.reserve(shape.size());
       torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
@@ -606,13 +602,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         torch_columnwise_shape.push_back(torch_shape[i]);
         columnwise_shape.push_back(shape[i]);
       }
-      } else {
-        // assert we are doing 1D scaling
-        NVTE_CHECK(block_scaling_dim == 1,
-                   "Compact columnwise format is not supported for 128x128 2D block scaling.");
-        torch_columnwise_shape = torch_shape;
-        columnwise_shape = shape;
-      }
     }
     auto scale_shape = get_scale_shape(shape, true);
     size_t sinv0 = scale_shape[0];
@@ -635,7 +624,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
         "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
         "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
-        "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format);
+        "is_2D_scaled"_a = (block_scaling_dim == 2));
   } else {
     py::handle Float8BlockwiseQTensorClass(
         reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
@@ -643,8 +632,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
         "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
         "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
-        "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2),
-        "data_format"_a = data_format);
+        "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2));
   }
 
   return {std::move(tensor), std::move(ret)};
@@ -654,6 +642,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
     py::object tensor) const {
   const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
   bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();
+  const bool with_gemm_swizzled_scales = true;
 
   // Extract buffers from Python tensor
   auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
@@ -675,13 +664,10 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
   opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
   scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
 
-  auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector<size_t> {
+  auto get_columnwise_shape = [&columnwise_data]() -> std::vector<size_t> {
     if (!columnwise_data) {
       return std::vector<size_t>();
     }
-    if (all_gather_usage) {
-      return getTensorShape(*columnwise_data);
-    }
     std::vector<size_t> shape = getTensorShape(*columnwise_data);
     std::vector<size_t> shape_transposed(shape.size());
     for (size_t i = 0; i + 1 < shape.size(); ++i) {
@@ -696,12 +682,12 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
   if (rowwise_data) {
     shape = getTensorShape(*rowwise_data);
     if (columnwise_data) {
-      auto expected_shape = get_columnwise_shape(all_gather_usage);
+      auto expected_shape = get_columnwise_shape();
       NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape,
                  ") and column-wise data (shape=", expected_shape, ") do not match");
     }
   } else {
-    shape = get_columnwise_shape(all_gather_usage);
+    shape = get_columnwise_shape();
   }
   std::vector<int64_t> torch_shape;
   for (auto s : shape) {
@@ -738,7 +724,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
     std::vector<size_t> columnwise_shape;
     std::vector<int64_t> torch_columnwise_shape;
     if (torch_shape.size() > 0) {
-      if (!all_gather_usage) {
       torch_columnwise_shape.reserve(torch_shape.size());
       columnwise_shape.reserve(shape.size());
       torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
@@ -747,13 +732,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
         torch_columnwise_shape.push_back(torch_shape[i]);
         columnwise_shape.push_back(shape[i]);
       }
-      } else {
-        // assert we are doing 1D scaling
-        NVTE_CHECK(block_scaling_dim == 1,
-                   "Compact columnwise format is not supported for 128x128 2D block scaling.");
-        torch_columnwise_shape = torch_shape;
-        columnwise_shape = shape;
-      }
     }
     if (!columnwise_data) {
       columnwise_data = at::empty(torch_columnwise_shape, opts);
@@ -798,6 +776,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
     const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
     ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape);
   }
+  ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
   set_quantization_params(&ret);
   return {std::move(ret), std::move(tensor)};
 }
@@ -813,9 +792,6 @@ void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& o
   }
   quant_config.set_force_pow_2_scales(force_pow_2_scales);
   quant_config.set_amax_epsilon(amax_epsilon);
-  if (all_gather_usage) {
-    quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
-  }
   NVTE_SCOPED_GIL_RELEASE({
     nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream());
   });
@@ -832,10 +808,6 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
   size_t m_dim = numel / k_dim;
   size_t kBlockLen = static_cast<size_t>(blockwise_fp8_block_len());
 
-  Float8BlockScaleTensorFormat data_format =
-      (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
-                        : Float8BlockScaleTensorFormat::GEMM_READY);
-
   std::vector<size_t> scale_shape;
 
   bool rowwise_usage = !columnwise;
@@ -845,23 +817,14 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
     size_t sinv0 = 0;
     size_t sinv1 = 0;
     if (block_scaling_dim == 2) {
-      // 2D scaling is always GEMM_READY for now
-      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
-                 "2D scaling is always GEMM_READY for now.");
-      sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
+      sinv0 = ceildiv(m_dim, kBlockLen);
+      sinv1 = roundup(ceildiv(k_dim, kBlockLen), 4);
     } else if (block_scaling_dim == 1) {
-      // 1D scaling can be GEMM_READY or COMPACT
-      bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
       // default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY
-      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
-      // if the rowwise format is compact, the scaling factor is not be transposed
-      if (rowwise_compact) {
-        std::swap(sinv0, sinv1);
-      }
+      sinv0 = ceildiv(k_dim, kBlockLen);
+      sinv1 = roundup(m_dim, 4);
     } else {
-      NVTE_CHECK(false,
+      NVTE_ERROR(
                  "Unsupported block_scaling_dim in create_tensor rowwise."
                  "Expected 1 or 2. Got ",
                  block_scaling_dim);
@@ -872,21 +835,13 @@ std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size
     size_t sinv0 = 0;
     size_t sinv1 = 0;
     if (block_scaling_dim == 2) {
-      // 2D scaling is always GEMM_READY for now
-      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
-                 "2D scaling is always GEMM_READY for now.");
-      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
+      sinv0 = ceildiv(k_dim, kBlockLen);
+      sinv1 = roundup(ceildiv(m_dim, kBlockLen), 4);
     } else if (block_scaling_dim == 1) {
-      // 1D scaling can be GEMM_READY or COMPACT
-      bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
-      sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
-      // GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS
-      // for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1]
-      // so no need to swap sinv0 and sinv1 here
+      sinv0 = ceildiv(m_dim, kBlockLen);
+      sinv1 = roundup(k_dim, 4);
    } else {
-      NVTE_CHECK(false,
+      NVTE_ERROR(
                  "Unsupported block_scaling_dim in create_tensor columnwise."
                  "Expected 1 or 2. Got ",
                  block_scaling_dim);
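With the COMPACT format gone, get_scale_shape reduces to a few ceildiv/roundup expressions. Restated standalone for both orientations, using the same formulas as the hunks above; the 128 block length is an assumed example value:

#include <cstdio>
#include <utility>

constexpr size_t ceildiv(size_t a, size_t b) { return (a + b - 1) / b; }
constexpr size_t roundup(size_t a, size_t b) { return ceildiv(a, b) * b; }

// GEMM_READY scale_inv shape for an m x k tensor (1D or 2D block scaling).
std::pair<size_t, size_t> scale_shape(size_t m, size_t k, int block_scaling_dim,
                                      bool columnwise, size_t block_len = 128) {
  if (block_scaling_dim == 2) {
    return columnwise ? std::pair{ceildiv(k, block_len), roundup(ceildiv(m, block_len), 4)}
                      : std::pair{ceildiv(m, block_len), roundup(ceildiv(k, block_len), 4)};
  }
  // 1D scaling: row-wise scales are stored pre-transposed for cuBLAS
  return columnwise ? std::pair{ceildiv(m, block_len), roundup(k, 4)}
                    : std::pair{ceildiv(k, block_len), roundup(m, 4)};
}

int main() {
  auto [s0, s1] = scale_shape(4096, 8192, 1, /*columnwise=*/false);
  std::printf("%zu x %zu\n", s0, s1);  // 64 x 4096
}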
@@ -906,6 +861,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
                                                                    DType dtype) const {
   using namespace pybind11::literals;
 
+  // Scaling factor format
+  const bool with_gemm_swizzled_scales = this->optimize_for_gemm;
+
   // Tensor dimensions
   const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
   size_t flat_first_dim = 1;
@@ -951,19 +909,17 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
   py::object out_py;
   if (internal) {
     py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass));
-    out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py,
-                              "columnwise_data"_a = columnwise_data_py,
-                              "rowwise_scale_inv"_a = rowwise_scale_inv_py,
-                              "columnwise_scale_inv"_a = columnwise_scale_inv_py,
-                              "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
+    out_py = MXFP8TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py,
+                              columnwise_scale_inv_py, this->dtype, this->quantizer,
+                              with_gemm_swizzled_scales);
   } else {
     py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass));
-    out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
-                              "rowwise_data"_a = rowwise_data_py,
-                              "columnwise_data"_a = columnwise_data_py,
-                              "rowwise_scale_inv"_a = rowwise_scale_inv_py,
-                              "columnwise_scale_inv"_a = columnwise_scale_inv_py,
-                              "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
+    out_py = MXFP8TensorClass(
+        "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
+        "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
+        "rowwise_scale_inv"_a = rowwise_scale_inv_py,
+        "columnwise_scale_inv"_a = columnwise_scale_inv_py, "fp8_dtype"_a = this->dtype,
+        "quantizer"_a = this->quantizer, "with_gemm_swizzled_scales"_a = with_gemm_swizzled_scales);
   }
 
   // Construct C++ MXFP8 tensor
@@ -978,6 +934,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
     out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0,
                                      columnwise_scale_inv_shape);
   }
+  out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
 
   this->set_quantization_params(&out_cpp);
   return {std::move(out_cpp), std::move(out_py)};
@@ -987,6 +944,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
     py::object tensor) const {
   NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor.");
 
+  // Scaling factor format
+  const bool with_gemm_swizzled_scales = this->optimize_for_gemm;
+
   // Extract buffers from Python tensor
   auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
     auto attr_py = tensor.attr(name);
@@ -1070,6 +1030,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
   // Coerce other attrs
   tensor.attr("_fp8_dtype") = dtype;
+  tensor.attr("_with_gemm_swizzled_scales") = with_gemm_swizzled_scales;
 
   // Construct C++ MXFP8 tensor
   TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING);
@@ -1083,6 +1044,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::convert_and_update_tensor(
     out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0,
                                      getTensorShape(*columnwise_scale_inv));
   }
+  out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
 
   this->set_quantization_params(&out_cpp);
   return {std::move(out_cpp), std::move(tensor)};
...@@ -1173,6 +1135,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve ...@@ -1173,6 +1135,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
DType dtype) const { DType dtype) const {
using namespace pybind11::literals; using namespace pybind11::literals;
// Scaling factor format
const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) self->optimize_for_gemm
// Tensor dimensions // Tensor dimensions
const std::vector<int64_t> shape_int64(shape.begin(), shape.end()); const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
size_t flat_first_dim = 1; size_t flat_first_dim = 1;
...@@ -1235,12 +1200,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve ...@@ -1235,12 +1200,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
py::object out_py; py::object out_py;
if (internal) { if (internal) {
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass)); py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass));
out_py = NVFP4TensorClass(rowwise_data_py, rowwise_scale_inv_py, columnwise_data_py,
columnwise_scale_inv_py, amax_rowwise_py, amax_columnwise_py,
this->dtype, this->quantizer, with_gemm_swizzled_scales);
} else {
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass));
out_py = NVFP4TensorClass(
...@@ -1249,7 +1211,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
"amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer); "quantizer"_a = this->quantizer, "with_gemm_swizzled_scales"_a = with_gemm_swizzled_scales);
} }
// Construct C++ tensor // Construct C++ tensor
...@@ -1272,6 +1234,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve ...@@ -1272,6 +1234,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
std::vector<size_t>{1}); std::vector<size_t>{1});
} }
out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(out_py)};
...@@ -1301,6 +1264,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
py::object tensor) const {
NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to NVFP4Tensor.");
// Scaling factor format
const bool with_gemm_swizzled_scales = false; // TODO (tmoon) Enable with optimize_for_gemm
// Extract buffers from Python tensor
auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
auto attr_py = tensor.attr(name);
...@@ -1438,6 +1404,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)};
...@@ -1468,20 +1435,40 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
}
size_t cols = input.size(input.ndim() - 1);
// Stochastic rounding
// When both rowwise and columnwise quantization are used with RHT,
// we need separate RNG states for each to ensure they use different random numbers.
TensorWrapper te_rng_state;
TensorWrapper te_rng_state_columnwise;
QuantizationConfigWrapper quant_config_columnwise;
const bool need_separate_columnwise_rng =
this->stochastic_rounding && this->with_rht && this->columnwise_usage;
if (this->stochastic_rounding) {
const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
// Generate RNG state for rowwise quantization
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
auto rng_state = torch::empty({2}, opts);
philox_unpack(philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
te_rng_state = makeTransformerEngineTensor(rng_state);
quant_config.set_rng_state(te_rng_state.data());
// Generate separate RNG state for columnwise quantization
if (need_separate_columnwise_rng) {
at::PhiloxCudaState philox_args_columnwise = init_philox_state(gen, rng_elts_per_thread);
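// NOTE (assumption): init_philox_state advances the generator's Philox
// offset on each call, so this columnwise state draws a disjoint random
// subsequence from the rowwise state generated above.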
auto rng_state_columnwise = torch::empty({2}, opts);
philox_unpack(philox_args_columnwise, static_cast<int64_t*>(rng_state_columnwise.data_ptr()));
te_rng_state_columnwise = makeTransformerEngineTensor(rng_state_columnwise);
quant_config_columnwise.set_stochastic_rounding(true);
quant_config_columnwise.set_rng_state(te_rng_state_columnwise.data());
}
}
// Restriction for the RHT cast fusion kernel, which uses MMA hardware to compute the RHT.
bool eligible_for_rht_cast_fusion =
input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
...@@ -1609,6 +1596,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
static_cast<DType>(out_columnwise_amax.dtype),
out_columnwise_amax.shape);
// Use separate RNG state for columnwise to ensure different random numbers than rowwise
auto& columnwise_quant_config =
need_separate_columnwise_rng ? quant_config_columnwise : quant_config;
if (!eligible_for_rht_cast_fusion) {
// Invoking fallback RHT kernel.
...@@ -1637,7 +1628,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
// Quantize kernel will treat everything as rowwise input/output, which is
// intended.
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), columnwise_quant_config,
stream);
});
} else {
// RHT cast fusion kernel.
...@@ -1648,8 +1640,9 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
NVTE_CHECK(false, "nvte_hadamard_transform_cast_fusion_columnwise is not supported in this build.");
#else
NVTE_SCOPED_GIL_RELEASE({
nvte_hadamard_transform_cast_fusion_columnwise(input.data(), out_transpose.data(),
rht_matrix_nvte.data(),
columnwise_quant_config, stream);
});
#endif
}
...
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...@@ -55,8 +55,9 @@ TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer
TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) {
auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING);
const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast<bool>();
NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for MXFP8 Tensor.");
...@@ -78,6 +79,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer)
getTensorShape(scale_inv));
}
// Scale layout
ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
// Quantizer state
quantizer->set_quantization_params(&ret);
...@@ -93,6 +97,7 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
// Row-wise data
if (rowwise_usage) {
const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
...@@ -102,6 +107,8 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise);
ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape);
}
// Column-wise data
if (columnwise_usage) {
const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
...@@ -112,7 +119,10 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape);
}
// Quantizer state
quantizer->set_quantization_params(&ret);
return ret;
}
...@@ -121,8 +131,9 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
auto ret = TensorWrapper(NVTE_NVFP4_1D_SCALING);
const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast<bool>();
NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor.");
...@@ -150,6 +161,9 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer)
getTensorShape(amax_columnwise));
}
// Scale layout
ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
// Quantizer state
quantizer->set_quantization_params(&ret);
...
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "util.h"
#include "common.h"
#include "common/common.h"
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
bool rowwise) {
using namespace transformer_engine::pytorch;
if (input.scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING &&
input.scaling_mode() != NVTE_NVFP4_1D_SCALING) {
return std::nullopt;
}
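// Note (assumption): block-scaled GEMMs expect MXFP8/NVFP4 scale factors in a
// 128x4 tiled ("swizzled") layout, hence the swizzle below; other scaling
// modes have nothing to swizzle here, so the helper returns no buffer.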
NVTE_CHECK(input.element_size_bits() == 4 || input.element_size_bits() == 8,
"4-bit or 8-bit input required for swizzling scaling factors.");
const auto nvfp4 = input.scaling_mode() == NVTE_NVFP4_1D_SCALING;
NVTEBasicTensor scale_inv;
NVTEShape nvte_input_shape;
if (rowwise) {
nvte_input_shape = input.shape();
scale_inv = input.get_rowwise_scale_inv();
} else {
nvte_input_shape = input.get_columnwise_data().shape;
scale_inv = input.get_columnwise_scale_inv();
}
auto input_shape = nvte_shape_to_vector(nvte_input_shape);
auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape.");
// Allocate memory for swizzled output.
auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
std::vector<int64_t> scale_inv_shape_int;
for (size_t i = 0; i < scale_inv_shape.size(); ++i) {
scale_inv_shape_int.push_back(static_cast<int64_t>(scale_inv_shape[i]));
}
auto swizzled_scale_inv = at::empty(scale_inv_shape_int, options);
void* scale_inv_dptr = scale_inv.data_ptr;
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
transformer_engine::TensorWrapper input_cu(input.scaling_mode());
transformer_engine::TensorWrapper output_cu(input.scaling_mode());
const auto input_dtype =
(nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3;
const auto scale_inv_dtype =
(nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0;
if (rowwise) {
input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
} else {
input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
}
// Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
} else {
input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
}
return swizzled_scale_inv;
}
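// Usage sketch (illustrative only; `launch_gemm` is a hypothetical stand-in
// for the real GEMM entry point). The returned buffer must stay in scope
// until the GEMM that reads `input` has been launched:
//
//   transformer_engine::TensorWrapper a = /* quantized MXFP8/NVFP4 input */;
//   auto a_swizzled = swizzle_scaling_factors(a, /*rowwise=*/true);
//   launch_gemm(a, b, out);  // `a_swizzled` keeps the swizzled scales alive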
std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
std::vector<transformer_engine::TensorWrapper>& tensors, bool rowwise) {
using namespace transformer_engine::pytorch;
if (tensors.empty()) {
return std::nullopt;
}
bool all_same_scaling_mode = std::all_of(
tensors.cbegin(), tensors.cend(), [&tensors](const transformer_engine::TensorWrapper& val) {
return val.scaling_mode() == tensors.front().scaling_mode();
});
NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same.");
if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING &&
tensors.front().scaling_mode() != NVTE_NVFP4_1D_SCALING) {
return std::nullopt;
}
const auto scaling_mode = tensors.front().scaling_mode();
const auto nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;
std::vector<transformer_engine::TensorWrapper> wrappers;
std::vector<NVTETensor> input_tensors, output_tensors;
// Collect scale_inv shapes and calculate buffer size and offsets for scale_invs
std::vector<std::vector<size_t>> scale_inv_shapes;
std::vector<void*> scale_inv_dptrs;
size_t buffer_size = 0;
std::vector<size_t> scale_inv_offsets;
constexpr size_t scale_elem_size = 1;
for (auto& tensor : tensors) {
NVTEBasicTensor scale_inv;
if (rowwise) {
scale_inv = tensor.get_rowwise_scale_inv();
} else {
scale_inv = tensor.get_columnwise_scale_inv();
}
auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_inv_offsets.push_back(buffer_size);
buffer_size += product(scale_inv_shape) * scale_elem_size;
scale_inv_shapes.emplace_back(scale_inv_shape);
scale_inv_dptrs.push_back(scale_inv.data_ptr);
}
// Allocate full buffer
auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8));
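// Design note: packing every tensor's swizzled scales into one 16B-aligned
// buffer gives the caller a single allocation to keep alive for the whole
// GEMM group, instead of one tensor per input.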
const auto input_dtype =
(nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3;
const auto scale_inv_dtype =
(nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0;
for (size_t i = 0; i < tensors.size(); ++i) {
auto& tensor = tensors[i];
void* scale_inv_dptr = scale_inv_dptrs[i];
void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);
// auto input_shape = nvte_shape_to_vector(tensor.shape());
NVTEShape nvte_input_shape;
if (rowwise) {
nvte_input_shape = tensor.shape();
} else {
nvte_input_shape = tensor.get_columnwise_data().shape;
}
auto input_shape = nvte_shape_to_vector(nvte_input_shape);
// Reconstruct the input so that only the requested direction is swizzled.
// The data dtype only needs to match the element width; the swizzle kernel
// does not read the data values themselves.
transformer_engine::TensorWrapper input_cu(scaling_mode);
transformer_engine::TensorWrapper output_cu(scaling_mode);
if (rowwise) {
input_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
output_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
scale_inv_shapes[i]);
// Set the swizzled scaling factor to the original tensor.
tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
} else {
input_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
output_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
scale_inv_shapes[i]);
// Set the swizzled scaling factor to the original tensor.
tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
scale_inv_shapes[i]);
}
input_tensors.emplace_back(input_cu.data());
output_tensors.emplace_back(output_cu.data());
wrappers.emplace_back(std::move(input_cu));
wrappers.emplace_back(std::move(output_cu));
}
// Launch kernel
nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(),
input_tensors.size(), at::cuda::getCurrentCUDAStream());
return buffer;
}
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input,
bool rowwise) {
using namespace transformer_engine::pytorch;
using transformer_engine::DIVUP;
// Check input tensor
const NVTEScalingMode scaling_mode = input.scaling_mode();
NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
"Input tensor must be a block scaling tensor");
// Get tensor data
NVTEBasicTensor data;
size_t data_flat_first_dim = 1;
size_t data_flat_last_dim = 1;
if (rowwise) {
data = input.get_rowwise_data();
for (int i = 0; i < data.shape.ndim - 1; ++i) {
data_flat_first_dim *= data.shape.data[i];
}
data_flat_last_dim = data.shape.data[data.shape.ndim - 1];
} else {
data = input.get_columnwise_data();
data_flat_first_dim = data.shape.data[0];
for (int i = 1; i < data.shape.ndim; ++i) {
data_flat_last_dim *= data.shape.data[i];
}
}
NVTEShape data_shape{};
data_shape.data[0] = data_flat_first_dim;
data_shape.data[1] = data_flat_last_dim;
data_shape.ndim = 2;
// Recreate input tensor with rowwise usage
transformer_engine::TensorWrapper input_cu(scaling_mode);
input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
const NVTEBasicTensor scale_inv =
rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv();
input_cu.set_rowwise_scale_inv(
scale_inv.data_ptr, static_cast<transformer_engine::DType>(scale_inv.dtype), scale_inv.shape);
// Create output tensor
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
// Output swizzled mxfp8 scaling factor dimensions
const size_t swizzled_scale_inv_first_dim = DIVUP<size_t>(data_flat_first_dim, 128) * 128;
const size_t swizzled_scale_inv_last_dim = DIVUP<size_t>(data_flat_last_dim, 128) * 4;
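// Worked example (assuming one E8M0 scale per 32-element block): for 256x512
// data, the first dim stays 256 (already a multiple of 128) and the last dim
// is DIVUP(512, 128) * 4 = 16, i.e. a 256x16 byte scale buffer, matching the
// 512 / 32 = 16 scales needed per row.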
// Allocate memory for swizzled mxfp8 scaling factors
const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
at::Tensor swizzled_scale_inv = at::empty(
std::vector<int64_t>{static_cast<int64_t>(swizzled_scale_inv_first_dim),
static_cast<int64_t>(swizzled_scale_inv_last_dim)},
options);
// Set rowwise scaling factors on output
void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
NVTEShape swizzled_scale_inv_shape{};
swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim;
swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim;
swizzled_scale_inv_shape.ndim = 2;
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
swizzled_scale_inv_shape);
// Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format
nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
// Replace the input tensor with the converted mxfp8 tensor and return the
// swizzled scaling factor so the caller can keep it alive during the GEMM.
input = std::move(output_cu);
return swizzled_scale_inv;
}
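// Usage sketch (illustrative only; `launch_gemm` is hypothetical). After the
// call, `input` is an mxfp8 tensor; with rowwise==false the caller must treat
// it as transposed, per the note in util.h:
//
//   auto scales = convert_block_scaling_to_mxfp8_tensor(w, /*rowwise=*/false);
//   launch_gemm(w, x, out, /*transa=*/true);  // `scales` kept alive here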
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...@@ -10,33 +10,44 @@
#include <torch/extension.h>
#include <optional>
#include <tuple>
#include <vector>
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
namespace pytorch {

/*! \brief Convert tensor block scales into GEMM swizzled format.
*
* The returned swizzled scales should be kept alive during the GEMM.
*/
std::tuple<std::optional<at::Tensor>, std::optional<at::Tensor>> swizzle_scales_for_gemm(
TensorWrapper& tensor, bool rowwise_usage, bool columnwise_usage);

/*! \brief Convert multiple tensor block scales into GEMM swizzled format.
*
* The returned swizzled scales should be kept alive during the GEMMs.
*/
std::optional<at::Tensor> multi_tensor_swizzle_scales_for_gemm(std::vector<TensorWrapper>& tensors,
bool rowwise_usage,
bool columnwise_usage);

/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place.
*
* If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid
* transposing it in memory. Due to differences in how block scaling and mxfp8 store data,
* this requires the calling code to treat the output tensor as having been transposed in
* this case.
*
* Returns the swizzled scaling factor of the converted mxfp8 tensor.
* The returned swizzled scaling factor tensor should be kept alive during the GEMM.
*/
at::Tensor convert_block_scaling_to_mxfp8_tensor(TensorWrapper& input, bool rowwise);

}  // namespace pytorch
}  // namespace transformer_engine
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...