Unverified Commit 3f5b4754 authored by Kirthi Shankar Sivamani, committed by GitHub

[Core][PyTorch] NVFP4 recipe (#2177)



* Add NVFP4 recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Frank Sun <frsun@nvidia.com>
Co-authored-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add MathDx dependency to GitHub builds
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Suggestions from GitHub Copilot
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move 2x shape logic from core to PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix compilation errors with CUDA 12.1
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* SM 70 is not supported in CUDA 13
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Typo
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Revert "Move 2x shape logic from core to PyTorch"

This reverts commit f8b2a2d0111d9af690b43bb98ae448d9a430a185.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Added dequantize kernel for FP4
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add NVFP4 support with fusible ops

Use logical tensor dims for PyTorch NVFP4 tensors. Temporarily add unfused dequantize impl. Fix bug where NVFP4 recipe was not configurable.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix logic for 2x shapes and move to PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CG test model config
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Debug NVFP4 tensor size function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Proper handling of the RNG state
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Test SR properly
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix workspace size for GEMM heuristic.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix compile error in C++ NVFP4 test

Fix some numeric errors when blocks are all zero.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix distributed test problem shape
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* proper assert dim for low precision AG TP
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* clean up duplicated code in nvfp4_utils.cuh
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* pylint: disable=unused-argument
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* `nvte_cublas_gemm_v2` to take alpha pointer (#12); a call-site sketch follows this block

* make nvte_cublas_gemm_v2 take alpha/beta pointers
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* users are expected to pass a valid C_tensor
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* typos
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* API to have const float* alpha
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Minor tweaks

Support arbitrary beta scales. Increase workspace to be aligned to 128 bytes.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug IMA with alpha pointer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
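A minimal sketch of the new call pattern described above. The variable names and argument order are copied from the gemm.cpp hunk later in this commit; the helper function itself, and the plain host-float alpha/beta, are illustrative assumptions (the real call site passes &beta.value() from a std::optional).

  // Hedged sketch: nvte_cublas_gemm_v2 now takes alpha/beta by pointer, so the scales can also
  // live in device memory. C is a required tensor; here it is aliased to the output D.
  void gemm_v2_call_site_example(transformer_engine::TensorWrapper& A_tensor, bool transa,
                                 transformer_engine::TensorWrapper& B_tensor, bool transb,
                                 transformer_engine::TensorWrapper& out_tensor,
                                 transformer_engine::TensorWrapper& te_workspace,
                                 bool use_split_accumulator, int num_math_sms,
                                 cudaStream_t main_stream) {
    transformer_engine::MatmulConfigWrapper config;
    config.set_use_split_accumulator(use_split_accumulator);
    config.set_sm_count(num_math_sms);
    const float alpha = 1.0f;  // passed by pointer below
    const float beta = 0.0f;   // arbitrary beta scales are supported
    nvte_cublas_gemm_v2(transa, transb, &alpha, A_tensor.data(), B_tensor.data(), &beta,
                        out_tensor.data(), out_tensor.data(), te_workspace.data(), config,
                        main_stream);
  }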

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Support fused amax kernels with NVFP4 quantization
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable fused amax with cuDNN LayerNorm kernel
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add NVFP4 cases to distributed tests for TE ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Change assert to NVTE_CHECK in the hadamard cast fusion
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix compile error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use global thread IDs for Philox subsequences
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add shape checks for NVFP4 cast kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Do not fuse amax if cuDNN normalization is forced by envvar
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Frank Sun <frsun@nvidia.com>
Co-authored-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent dfeef1a2
@@ -8,179 +8,269 @@
 #include "common.h"
 #include "pybind.h"

-namespace transformer_engine::pytorch {
+namespace transformer_engine {
+namespace pytorch {
+
+namespace {

-template <void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t)>
-py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) {
+py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t),
+                              const at::Tensor& input, py::handle quantizer,
+                              int shape_divisor = 1) {
   init_extension();

   // Input tensor
   auto input_tensor = input.contiguous();
-  const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
+  const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor);

   // Construct output tensor
   auto quantizer_cpp = convert_quantizer(quantizer);
-  const auto input_shape = input_cpp.shape();
+  const auto input_shape = input_nvte.shape();
   std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
   output_shape.back() /= shape_divisor;
   auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
-  auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);
+  auto [out_nvte, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);

-  // Compute activation
+  // Choose implementation
+  enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 };
+  Impl impl = Impl::UNFUSED;
   if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
       detail::IsMXFP8Quantizers(quantizer.ptr())) {
-    // Compute activation directly
-    NVTE_SCOPED_GIL_RELEASE(
-        { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); });
+    impl = Impl::FULLY_FUSED;
   } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
-    // Compute activation in high-precision fused together with amax, then quantize.
-    auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
-    auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype);
-    NVTE_SCOPED_GIL_RELEASE(
-        { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
-    quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp);
-  } else {
-    // Compute activation in high-precision, then quantize
-    auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
-    NVTE_SCOPED_GIL_RELEASE(
-        { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
-    quantizer_cpp->quantize(temp_cpp, out_cpp);
+    impl = Impl::FUSED_ACTIVATION_AMAX_FP8;
+  } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
+    auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
+    NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
+      // Post-RHT amax is handled within NVFP4 quantizer
+      impl = Impl::UNFUSED;
+    } else {
+      impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
+    }
+  }
+
+  // Perform compute
+  auto stream = at::cuda::getCurrentCUDAStream();
+  switch (impl) {
+    case Impl::UNFUSED:
+      // Compute activation in high precision, then quantize
+      {
+        auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
+        NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
+        quantizer_cpp->quantize(temp_nvte, out_nvte);
+      }
+      break;
+    case Impl::FULLY_FUSED:
+      // Compute activation directly
+      {
+        NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), out_nvte.data(), stream); });
+      }
+      break;
+    case Impl::FUSED_ACTIVATION_AMAX_FP8:
+      // Compute activation and amax in high precision, then quantize to FP8
+      {
+        auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
+        NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
+        auto [temp_nvte, _] =
+            fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype);
+        NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
+        fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
+      }
+      break;
+    case Impl::FUSED_ACTIVATION_AMAX_NVFP4:
+      // Compute activation and amax in high precision, then quantize to NVFP4
+      {
+        auto nvfp4_quantizer_cpp =
+            static_cast<NVFP4Quantizer*>(quantizer_cpp.get());  // Already checked cast is valid
+        auto [temp_nvte, _] =
+            nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype);
+        NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
+        nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
+      }
+      break;
+    default:
+      NVTE_ERROR("Invalid activation implementation (", static_cast<int>(impl), ")");
   }

   return out_py;
 }

-template <void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
-py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input,
-                              py::handle quantizer) {
+py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor,
+                                                 cudaStream_t),
+                               const at::Tensor& grad_output, const at::Tensor& input,
+                               py::handle quantizer) {
   init_extension();

   // Grad output and input tensors
   auto grad_output_tensor = grad_output.contiguous();
   auto input_tensor = input.contiguous();
-  const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor);
-  const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
+  const TensorWrapper& grad_output_nvte = makeTransformerEngineTensor(grad_output_tensor);
+  const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor);

   // Construct grad input tensor
   auto quantizer_cpp = convert_quantizer(quantizer);
-  const auto input_shape_te = input_cpp.shape();
+  const auto input_shape_te = input_nvte.shape();
   const std::vector<size_t> input_shape(input_shape_te.data,
                                         input_shape_te.data + input_shape_te.ndim);
   auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
-  auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);
+  auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);

-  // Compute activation backward
+  // Choose implementation
+  enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 };
+  Impl impl = Impl::UNFUSED;
   if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
       detail::IsMXFP8Quantizers(quantizer.ptr())) {
-    // Compute activation backward directly
-    NVTE_SCOPED_GIL_RELEASE({
-      dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(),
-                at::cuda::getCurrentCUDAStream());
-    });
+    impl = Impl::FULLY_FUSED;
   } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
-    // Compute activation backward in high-precision fused together with amax, then quantize.
-    auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
-    auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype);
-    NVTE_SCOPED_GIL_RELEASE({
-      dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
-                at::cuda::getCurrentCUDAStream());
-    });
-    quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp);
-  } else {
-    // Compute activation backward in high-precision, then quantize
-    auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
-    NVTE_SCOPED_GIL_RELEASE({
-      dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
-                at::cuda::getCurrentCUDAStream());
-    });
-    quantizer_cpp->quantize(temp_cpp, grad_input_cpp);
+    impl = Impl::FUSED_ACTIVATION_AMAX_FP8;
+  } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
+    auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
+    NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
+      // Post-RHT amax is handled within NVFP4 quantizer
+      impl = Impl::UNFUSED;
+    } else {
+      impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
+    }
+  }
+
+  // Perform compute
+  auto stream = at::cuda::getCurrentCUDAStream();
+  switch (impl) {
+    case Impl::UNFUSED:
+      // Compute activation backward in high precision, then quantize
+      {
+        auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
+        NVTE_SCOPED_GIL_RELEASE({
+          dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
+                    at::cuda::getCurrentCUDAStream());
+        });
+        quantizer_cpp->quantize(temp_nvte, grad_input_nvte);
+      }
+      break;
+    case Impl::FULLY_FUSED:
+      // Compute activation backward directly
+      {
+        NVTE_SCOPED_GIL_RELEASE({
+          dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream);
+        });
+      }
+      break;
+    case Impl::FUSED_ACTIVATION_AMAX_FP8:
+      // Compute activation and amax in high precision, then quantize to FP8
+      {
+        auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
+        NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
+        auto [temp_nvte, _] =
+            fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype);
+        NVTE_SCOPED_GIL_RELEASE(
+            { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); });
+        fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
+      }
+      break;
+    case Impl::FUSED_ACTIVATION_AMAX_NVFP4:
+      // Compute activation and amax in high precision, then quantize to NVFP4
+      {
+        auto nvfp4_quantizer_cpp =
+            static_cast<NVFP4Quantizer*>(quantizer_cpp.get());  // Already checked cast is valid
+        auto [temp_nvte, _] =
+            nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype);
+        NVTE_SCOPED_GIL_RELEASE(
+            { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); });
+        nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
+      }
+      break;
+    default:
+      NVTE_ERROR("Invalid activation implementation (", static_cast<int>(impl), ")");
   }

   return grad_input_py;
 }

-/* GELU and variants*/
+}  // namespace
+
+/* GELU and variants */
 py::object gelu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_gelu>(input, quantizer);
+  return activation_forward(nvte_gelu, input, quantizer);
 }

 py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dgelu>(grad, input, quantizer);
+  return activation_backward(nvte_dgelu, grad, input, quantizer);
 }

 py::object geglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_geglu>(input, quantizer, 2);
+  return activation_forward(nvte_geglu, input, quantizer, 2);
 }

 py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dgeglu>(grad, input, quantizer);
+  return activation_backward(nvte_dgeglu, grad, input, quantizer);
 }

 py::object qgelu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_qgelu>(input, quantizer);
+  return activation_forward(nvte_qgelu, input, quantizer);
 }

 py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dqgelu>(grad, input, quantizer);
+  return activation_backward(nvte_dqgelu, grad, input, quantizer);
 }

 py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_qgeglu>(input, quantizer, 2);
+  return activation_forward(nvte_qgeglu, input, quantizer, 2);
 }

 py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dqgeglu>(grad, input, quantizer);
+  return activation_backward(nvte_dqgeglu, grad, input, quantizer);
 }

-/* ReLU and variants*/
+/* ReLU and variants */
 py::object relu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_relu>(input, quantizer);
+  return activation_forward(nvte_relu, input, quantizer);
 }

 py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_drelu>(grad, input, quantizer);
+  return activation_backward(nvte_drelu, grad, input, quantizer);
 }

 py::object reglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_reglu>(input, quantizer, 2);
+  return activation_forward(nvte_reglu, input, quantizer, 2);
 }

 py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dreglu>(grad, input, quantizer);
+  return activation_backward(nvte_dreglu, grad, input, quantizer);
 }

 py::object srelu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_srelu>(input, quantizer);
+  return activation_forward(nvte_srelu, input, quantizer);
 }

 py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dsrelu>(grad, input, quantizer);
+  return activation_backward(nvte_dsrelu, grad, input, quantizer);
 }

 py::object sreglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_sreglu>(input, quantizer, 2);
+  return activation_forward(nvte_sreglu, input, quantizer, 2);
 }

 py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dsreglu>(grad, input, quantizer);
+  return activation_backward(nvte_dsreglu, grad, input, quantizer);
 }

-/* Silu and variants*/
+/* Silu and variants */
 py::object silu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_silu>(input, quantizer);
+  return activation_forward(nvte_silu, input, quantizer);
 }

 py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dsilu>(grad, input, quantizer);
+  return activation_backward(nvte_dsilu, grad, input, quantizer);
 }

 py::object swiglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_swiglu>(input, quantizer, 2);
+  return activation_forward(nvte_swiglu, input, quantizer, 2);
 }

 py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dswiglu>(grad, input, quantizer);
+  return activation_backward(nvte_dswiglu, grad, input, quantizer);
 }

-}  // namespace transformer_engine::pytorch
+}  // namespace pytorch
+}  // namespace transformer_engine
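The forward and backward helpers above share one pattern for amax-fused quantization. A distilled sketch follows; it reuses the names from this hunk (create_unquantized_tensor_with_amax, quantize_with_amax, NVTE_SCOPED_GIL_RELEASE), but the free function itself is not part of the original file and is only meant to show the flow.

  // Hedged sketch of the FUSED_ACTIVATION_AMAX_NVFP4 path: run the kernel in high precision
  // into a temporary tensor that records amax, then quantize using that recorded amax.
  void nvfp4_amax_fused_example(void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t),
                                const TensorWrapper& input_nvte, NVFP4Quantizer* quantizer_cpp,
                                TensorWrapper& out_nvte, DType fake_dtype, cudaStream_t stream) {
    auto [temp_nvte, temp_py] =
        quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype);
    (void)temp_py;  // keeps the temporary's Python storage alive while the kernel runs
    NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
    quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);  // consumes the amax recorded above
  }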
@@ -35,22 +35,6 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
       { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); });
 }

-void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) {
-  NVTE_SCOPED_GIL_RELEASE({
-    nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
-                                 arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_,
-                                 at::cuda::getCurrentCUDAStream());
-  });
-}
-
-// extract PhiloxCudaState from CUDA random number generator
-at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_per_thread) {
-  at::PhiloxCudaState philox_args;
-  std::lock_guard<std::mutex> lock(gen->mutex_);
-  philox_args = gen->philox_cuda_state(elts_per_thread);
-  return philox_args;
-}
-
 }  // namespace

 namespace transformer_engine::pytorch {

@@ -198,7 +182,7 @@ std::vector<py::object> fused_attn_fwd(
       rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
   at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
   auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
-  unpack(philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
+  philox_unpack(philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
   auto te_rng_state = makeTransformerEngineTensor(rng_state);

   // create auxiliary output tensors
...
@@ -122,13 +122,27 @@ std::vector<py::object> dact_dbias(
   }

   // Choose implementation
-  enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX };
+  enum class Impl {
+    UNFUSED,
+    FUSED_DACT_DBIAS_QUANTIZE,
+    FUSED_DACT_AMAX_FP8,
+    FUSED_DACT_AMAX_NVFP4
+  };
   Impl impl = Impl::UNFUSED;
   if (detail::IsFloat8Quantizers(quantizer_py.ptr()) ||
       detail::IsMXFP8Quantizers(quantizer_py.ptr())) {
     impl = Impl::FUSED_DACT_DBIAS_QUANTIZE;
   } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
-    impl = Impl::FUSED_DACT_AMAX;
+    impl = Impl::FUSED_DACT_AMAX_FP8;
+  } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) {
+    auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
+    NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
+      // Post-RHT amax is handled within NVFP4 quantizer
+      impl = Impl::UNFUSED;
+    } else {
+      impl = Impl::FUSED_DACT_AMAX_NVFP4;
+    }
   }

   // Perform compute

@@ -172,20 +186,38 @@ std::vector<py::object> dact_dbias(
       });
       break;
     }
-    case Impl::FUSED_DACT_AMAX:
-      // Fused dact-amax kernel, unfused dbias and quantize
-    {
-      auto *quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
-      NVTE_CHECK(quantizer_cpp_cs != nullptr,
-                 "Invalid quantizer for fused dact-amax kernel impl");
-      auto [temp_nvte, temp_py] =
-          quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, grad_output_dtype);
+    case Impl::FUSED_DACT_AMAX_FP8:
+      // Fused dact-amax kernel, unfused dbias and FP8 quantize
+    {
+      auto *fp8_quantizer_cpp =
+          dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+      NVTE_CHECK(fp8_quantizer_cpp != nullptr,
+                 "Invalid quantizer for fused dact-amax kernel impl");
+      auto [temp_nvte, temp_py] =
+          fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, grad_output_dtype);
+      NVTE_SCOPED_GIL_RELEASE({
+        dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
+      });
+      const auto temp_torch = temp_py.cast<at::Tensor>();
+      at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
+      fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
+      break;
+    }
+    case Impl::FUSED_DACT_AMAX_NVFP4:
+      // Fused dact-amax kernel, unfused dbias and NVFP4 quantize
+    {
+      auto *nvfp4_quantizer_cpp =
+          static_cast<NVFP4Quantizer *>(quantizer_cpp.get());  // Already checked cast is valid
+      NVTE_CHECK(nvfp4_quantizer_cpp != nullptr,
+                 "Invalid quantizer for fused dact-amax kernel impl");
+      auto [temp_nvte, temp_py] = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(
+          grad_input_nvte, grad_output_dtype);
       NVTE_SCOPED_GIL_RELEASE({
         dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
       });
       const auto temp_torch = temp_py.cast<at::Tensor>();
       at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
-      quantizer_cpp_cs->quantize_with_amax(temp_nvte, grad_input_nvte);
+      nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
       break;
     }
     default:
...
@@ -213,6 +213,19 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   const int sm_count = transformer_engine::cuda::sm_count(device_id);
   int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

+  // Construct GEMM config
+  transformer_engine::MatmulConfigWrapper config;
+  if (grad) {
+    config.set_dbias_tensor(bias_tensor.data());
+    config.set_with_dgelu_epilogue(gelu);
+  } else {
+    config.set_bias_tensor(bias_tensor.data());
+    config.set_with_gelu_epilogue(gelu);
+  }
+  config.set_epilogue_aux_tensor(te_pre_gelu_out.data());
+  config.set_use_split_accumulator(use_split_accumulator);
+  config.set_sm_count(num_math_sms);
+
   // Keep the swizzled scaling factor tensors alive during the GEMM.
   std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
   auto main_stream = at::cuda::getCurrentCUDAStream();

@@ -276,10 +289,9 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
     } else {
       // Launch GEMM
       NVTE_SCOPED_GIL_RELEASE({
-        nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(),
-                                bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
-                                te_workspace.data(), alpha, *beta, use_split_accumulator,
-                                num_math_sms, main_stream);
+        nvte_cublas_gemm_v2(transa, transb, &alpha, A_tensor.data(), B_tensor.data(), &beta.value(),
+                            out_tensor.data(), out_tensor.data(), te_workspace.data(), config,
+                            main_stream);
       });
     }
   } else {
...
@@ -66,67 +66,102 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   // Input and param tensors
   auto none = py::none();
-  const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none);
-  const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none);
-  TensorWrapper bias_cu;
+  const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
+  const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none);
+  TensorWrapper bias_nvte;
   if (bias.has_value()) {
-    bias_cu = makeTransformerEngineTensor(*bias);
+    bias_nvte = makeTransformerEngineTensor(*bias);
   }

   // Tensor dimensions
-  const size_t N = static_cast<size_t>(input_cu.size(0));
-  const size_t H = static_cast<size_t>(input_cu.size(1));
-  const std::vector<size_t> size = {N, H};
+  const auto shape = nvte_shape_to_vector(input_nvte.shape());
+  const auto outer_size = product(shape) / shape.back();
+  const auto inner_size = shape.back();

   // Tensors to save for backward pass
-  at::Tensor mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
-  at::Tensor rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
-  TensorWrapper mu_cu = makeTransformerEngineTensor(mu);
-  TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma);
+  at::Tensor mu_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
+  at::Tensor rsigma_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
+  TensorWrapper mu_nvte = makeTransformerEngineTensor(mu_py);
+  TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);

   // Output tensor
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  TensorWrapper out_cu;
+  auto quantizer_cpp = convert_quantizer(quantizer);
+  TensorWrapper out_nvte;
   if (out.is_none()) {
-    std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype);
+    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
   } else {
-    out_cu = makeTransformerEngineTensor(out, quantizer);
+    out_nvte = makeTransformerEngineTensor(out, quantizer);
   }

-  // Determine whether to avoid fused kernel
-  bool force_unfused_kernel = true;
-  if (quantizer.is_none()) {
-    // No need for separate quantization step if output is unquantized
-    force_unfused_kernel = false;
-  } else if (IsFloat8Quantizers(quantizer.ptr())) {
-    // Always used fused kernel for FP8 delayed scaling
-    force_unfused_kernel = false;
+  // Choose implementation
+  enum class Impl {
+    // Compute norm in high precision, then quantize
+    UNFUSED,
+    // Compute norm directly
+    FULLY_FUSED,
+    // Compute norm and amax in high precision, then quantize to FP8
+    FUSED_NORM_AMAX_FP8,
+    // Compute norm and amax in high precision, then quantize to NVFP4
+    FUSED_NORM_AMAX_NVFP4
+  };
+  Impl impl = Impl::UNFUSED;
+  if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) {
+    impl = Impl::FULLY_FUSED;
   } else if (IsMXFP8Quantizers(quantizer.ptr())) {
-    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
-      // cuDNN MXFP8 kernel requires full tile
-      force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
+    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 &&
+        inner_size % 128 == 0) {
+      // cuDNN MXFP8 kernel requires full 128x128 tiles
+      impl = Impl::FULLY_FUSED;
+    }
+  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
+             !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+    auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+    NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
+    impl = Impl::FUSED_NORM_AMAX_FP8;
+  } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
+    auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
+    NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
+      // Post-RHT amax is handled within NVFP4 quantizer
+      impl = Impl::UNFUSED;
+    } else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+      // TE kernel supports amax output
+      impl = Impl::FUSED_NORM_AMAX_NVFP4;
     }
   }
-  TensorWrapper unquantized_out_cu;
+
+  // Construct unquantized output tensor if needed
+  TensorWrapper unquantized_out_nvte;
   py::object unquantized_out;
-  if (force_unfused_kernel) {
-    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
-        !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
-      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
-      std::tie(unquantized_out_cu, unquantized_out) =
-          my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
-    } else {
-      NoneQuantizer q{none};
-      std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
-    }
-  }
-  TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
+  TensorWrapper *kernel_out_nvte = &out_nvte;
+  switch (impl) {
+    case Impl::UNFUSED: {
+      NoneQuantizer q{none};
+      std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype);
+      kernel_out_nvte = &unquantized_out_nvte;
+    } break;
+    case Impl::FUSED_NORM_AMAX_FP8: {
+      auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+      std::tie(unquantized_out_nvte, unquantized_out) =
+          fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype);
+      kernel_out_nvte = &unquantized_out_nvte;
+    } break;
+    case Impl::FUSED_NORM_AMAX_NVFP4: {
+      auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
+      std::tie(unquantized_out_nvte, unquantized_out) =
+          nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype);
+      kernel_out_nvte = &unquantized_out_nvte;
+    } break;
+    default: {
+    }
+  }

   // Query workspace size
   TensorWrapper workspace;
   NVTE_SCOPED_GIL_RELEASE({
-    nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
-                       mu_cu.data(), rsigma_cu.data(), workspace.data(),
+    nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps,
+                       kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(),
+                       workspace.data(),
                        at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        zero_centered_gamma, at::cuda::getCurrentCUDAStream());
   });

@@ -138,24 +173,31 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   // Launch kernel
   NVTE_SCOPED_GIL_RELEASE({
-    nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
-                       mu_cu.data(), rsigma_cu.data(), workspace.data(),
+    nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps,
+                       kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(),
+                       workspace.data(),
                        at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        zero_centered_gamma, at::cuda::getCurrentCUDAStream());
   });

-  // Quantize output if using unfused kernel
-  if (force_unfused_kernel) {
-    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
-        !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
-      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
-      my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
-    } else {
-      my_quantizer->quantize(unquantized_out_cu, out_cu);
+  // Quantize output if needed
+  switch (impl) {
+    case Impl::UNFUSED: {
+      quantizer_cpp->quantize(unquantized_out_nvte, out_nvte);
+    } break;
+    case Impl::FUSED_NORM_AMAX_FP8: {
+      auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+      fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
+    } break;
+    case Impl::FUSED_NORM_AMAX_NVFP4: {
+      auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
+      nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
+    } break;
+    default: {
     }
   }

-  return {out, py::cast(mu), py::cast(rsigma)};
+  return {out, py::cast(mu_py), py::cast(rsigma_py)};
 }

 std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,

@@ -254,61 +296,95 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   // Input and param tensors
   auto none = py::none();
-  const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none);
-  const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none);
+  const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
+  const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none);

   // Tensor dimensions
-  const size_t N = static_cast<size_t>(input_cu.shape().data[0]);
-  const size_t H = static_cast<size_t>(input_cu.shape().data[1]);
-  const std::vector<size_t> size = {N, H};
+  const auto shape = nvte_shape_to_vector(input_nvte.shape());
+  const auto outer_size = product(shape) / shape.back();
+  const auto inner_size = shape.back();

   // Tensors to save for backward pass
-  auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
-  auto rsigma_cu = makeTransformerEngineTensor(rsigma);
+  at::Tensor rsigma_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
+  TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);

   // Output tensor
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  TensorWrapper out_cu;
+  auto quantizer_cpp = convert_quantizer(quantizer);
+  TensorWrapper out_nvte;
   if (out.is_none()) {
-    std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype);
+    std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
   } else {
-    out_cu = makeTransformerEngineTensor(out, quantizer);
+    out_nvte = makeTransformerEngineTensor(out, quantizer);
   }

-  // Determine whether to avoid fused kernel
-  bool force_unfused_kernel = true;
-  if (quantizer.is_none()) {
-    // No need for separate quantization step if output is unquantized
-    force_unfused_kernel = false;
-  } else if (IsFloat8Quantizers(quantizer.ptr())) {
-    // Always used fused kernel for FP8 delayed scaling
-    force_unfused_kernel = false;
+  // Choose implementation
+  enum class Impl {
+    // Compute norm in high precision, then quantize
+    UNFUSED,
+    // Compute norm directly
+    FULLY_FUSED,
+    // Compute norm and amax in high precision, then quantize to FP8
+    FUSED_NORM_AMAX_FP8,
+    // Compute norm and amax in high precision, then quantize to NVFP4
+    FUSED_NORM_AMAX_NVFP4
+  };
+  Impl impl = Impl::UNFUSED;
+  if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) {
+    impl = Impl::FULLY_FUSED;
   } else if (IsMXFP8Quantizers(quantizer.ptr())) {
-    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
-      // cuDNN MXFP8 kernel requires full tile
-      force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
+    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 &&
+        inner_size % 128 == 0) {
+      // cuDNN MXFP8 kernel requires full 128x128 tiles
+      impl = Impl::FULLY_FUSED;
+    }
+  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
+             !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+    auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+    NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
+    impl = Impl::FUSED_NORM_AMAX_FP8;
+  } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
+    auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
+    NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
+    if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
+      // Post-RHT amax is handled within NVFP4 quantizer
+      impl = Impl::UNFUSED;
+    } else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+      // TE kernel supports amax output
+      impl = Impl::FUSED_NORM_AMAX_NVFP4;
     }
   }
-  TensorWrapper unquantized_out_cu;
+
+  // Construct unquantized output tensor if needed
+  TensorWrapper unquantized_out_nvte;
   py::object unquantized_out;
-  if (force_unfused_kernel) {
-    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
-        !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
-      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
-      std::tie(unquantized_out_cu, unquantized_out) =
-          my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
-    } else {
-      NoneQuantizer q{none};
-      std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
-    }
-  }
-  TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
+  TensorWrapper *kernel_out_nvte = &out_nvte;
+  switch (impl) {
+    case Impl::UNFUSED: {
+      NoneQuantizer q{none};
+      std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype);
+      kernel_out_nvte = &unquantized_out_nvte;
+    } break;
+    case Impl::FUSED_NORM_AMAX_FP8: {
+      auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+      std::tie(unquantized_out_nvte, unquantized_out) =
+          fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype);
+      kernel_out_nvte = &unquantized_out_nvte;
+    } break;
+    case Impl::FUSED_NORM_AMAX_NVFP4: {
+      auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
+      std::tie(unquantized_out_nvte, unquantized_out) =
+          nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype);
+      kernel_out_nvte = &unquantized_out_nvte;
+    } break;
+    default: {
+    }
+  }

   // Query workspace size
   TensorWrapper workspace;
   NVTE_SCOPED_GIL_RELEASE({
-    nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
-                     workspace.data(),
+    nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(),
+                     rsigma_nvte.data(), workspace.data(),
                      at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                      zero_centered_gamma, at::cuda::getCurrentCUDAStream());
   });

@@ -320,24 +396,30 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   // Launch kernel
   NVTE_SCOPED_GIL_RELEASE({
-    nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
-                     workspace.data(),
+    nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(),
+                     rsigma_nvte.data(), workspace.data(),
                      at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                      zero_centered_gamma, at::cuda::getCurrentCUDAStream());
   });

-  // Quantize output if using unfused kernel
-  if (force_unfused_kernel) {
-    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
-        !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
-      auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
-      my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
-    } else {
-      my_quantizer->quantize(unquantized_out_cu, out_cu);
+  // Quantize output if needed
+  switch (impl) {
+    case Impl::UNFUSED: {
+      quantizer_cpp->quantize(unquantized_out_nvte, out_nvte);
+    } break;
+    case Impl::FUSED_NORM_AMAX_FP8: {
+      auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+      fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
+    } break;
+    case Impl::FUSED_NORM_AMAX_NVFP4: {
+      auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
+      nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
+    } break;
+    default: {
     }
   }

-  return {out, py::none(), py::cast(rsigma)};
+  return {out, py::none(), py::cast(rsigma_py)};
 }

 }  // namespace transformer_engine::pytorch
@@ -32,6 +32,9 @@ PyTypeObject *MXFP8QuantizerClass = nullptr;
 PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr;
 PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr;
 PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
+PyTypeObject *NVFP4TensorPythonClass = nullptr;
+PyTypeObject *NVFP4TensorBasePythonClass = nullptr;
+PyTypeObject *NVFP4QuantizerClass = nullptr;

 void init_float8_extension() {
   if (Float8TensorPythonClass) return;

@@ -86,10 +89,26 @@ void init_float8blockwise_extension() {
              "Internal error: could not initialize pyTorch float8blockwise extension.");
 }

+void init_nvfp4_extensions() {
+  if (NVFP4TensorPythonClass) return;
+  auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor");
+  NVFP4QuantizerClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer"));
+  NVFP4TensorPythonClass =
+      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Tensor"));
+  auto nvfp4_base_module =
+      py::module_::import("transformer_engine.pytorch.tensor._internal.nvfp4_tensor_base");
+  NVFP4TensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorBase"));
+  NVTE_CHECK(NVFP4TensorPythonClass != nullptr,
+             "Internal error: could not initialize pyTorch NVFP4 extension.");
+}
+
 void init_extension() {
   init_float8_extension();
   init_mxfp8_extension();
   init_float8blockwise_extension();
+  init_nvfp4_extensions();
 }

 }  // namespace transformer_engine::pytorch
...
@@ -40,13 +40,12 @@ extern PyTypeObject *MXFP8QuantizerClass;
 extern PyTypeObject *Float8BlockwiseQTensorPythonClass;
 extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass;
 extern PyTypeObject *Float8BlockwiseQuantizerClass;
+extern PyTypeObject *NVFP4TensorPythonClass;
+extern PyTypeObject *NVFP4TensorBasePythonClass;
+extern PyTypeObject *NVFP4QuantizerClass;

 void init_extension();
-void init_float8_extension();
-void init_mxfp8_extension();

 namespace detail {

 inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; }

@@ -69,11 +68,17 @@ inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) {
   return Py_TYPE(obj) == Float8BlockwiseQuantizerClass;
 }

+inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4QuantizerClass; }
+
 inline bool IsFloat8BlockwiseQTensor(PyObject *obj) {
   return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass ||
          Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass;
 }

+inline bool IsNVFP4Tensor(PyObject *obj) {
+  return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorBasePythonClass;
+}
+
 TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer);

 template <typename T>

@@ -88,6 +93,8 @@ std::unique_ptr<Quantizer> CreateMXFP8Params(const py::handle params);
 TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor,
                                                    Quantizer *quantization_params);

+TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer);
+
 inline bool IsFloatingPointType(at::ScalarType type) {
   return type == at::kFloat || type == at::kHalf || type == at::kBFloat16;
 }

@@ -100,8 +107,9 @@ constexpr std::array custom_types_converters = {
     std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor,
                     CreateQuantizer<MXFP8Quantizer>),
     std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers,
-                    NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer<Float8BlockQuantizer>)};
+                    NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer<Float8BlockQuantizer>),
+    std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor,
+                    CreateQuantizer<NVFP4Quantizer>)};

 }  // namespace detail

 }  // namespace transformer_engine::pytorch
...
@@ -31,8 +31,20 @@ std::vector<T> make_transpose_shape(const std::vector<S>& shape) {
   return ret;
 }

+/*! @brief Convert shape for FP4 data by dividing the last dimension by 2 */
+template <typename T = size_t>
+std::vector<T> convert_shape_for_fp4(const std::vector<T>& shape) {
+  std::vector<T> ret;
+  for (size_t i = 0; i < shape.size() - 1; ++i) {
+    ret.push_back(shape[i]);
+  }
+  ret.push_back(shape.back() / 2);
+  return ret;
+}
+
 }  // namespace

+constexpr size_t NVFP4_BLOCK_SIZE = 16;
 constexpr size_t MXFP8_BLOCK_SIZE = 32;

 Quantizer::Quantizer(const py::handle& quantizer) {
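A quick illustration of the packing arithmetic implied by convert_shape_for_fp4 and NVFP4_BLOCK_SIZE above. The example shape is made up; only the divide-by-2 packing and the 16-element scaling block come from this hunk, and the scale-count estimate ignores any alignment padding the real get_scale_shape may apply.

  #include <cstdio>
  #include <vector>

  int main() {
    const std::vector<size_t> shape = {128, 64};  // logical tensor dims (example values)
    // FP4 stores two 4-bit values per byte, so the packed data buffer halves the last dim.
    const size_t packed_last_dim = shape.back() / 2;   // 32
    // Each group of NVFP4_BLOCK_SIZE = 16 consecutive elements shares one scale factor,
    // so a row-wise scale buffer needs roughly last_dim / 16 entries per row.
    const size_t scales_per_row = shape.back() / 16;   // 4
    std::printf("packed last dim: %zu, scales per row: %zu\n", packed_last_dim, scales_per_row);
    return 0;
  }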
@@ -376,8 +388,9 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
   return {std::move(out_cpp), std::move(out_py)};
 }

-std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_hp_tensor_with_amax(
-    const std::vector<size_t>& shape, DType dtype) {
+std::pair<TensorWrapper, py::object>
+Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector<size_t>& shape,
+                                                                   DType dtype) {
   amax.zero_();
   auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype);
   out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),

@@ -899,7 +912,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
   }
   const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
   NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0,
-             "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
+             "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE,
              " (got shape=", shape, ")");
   const auto rowwise_scale_inv_shape = get_scale_shape(shape, false);
   const auto columnwise_scale_inv_shape = get_scale_shape(shape, true);

@@ -1095,7 +1108,7 @@ std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& s
   auto last_dim = shape.back();
   NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
-             "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
+             "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE,
              " (got shape=", shape, ")");

   std::vector<size_t> scale_shape;

@@ -1116,4 +1129,573 @@ std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& s
   return scale_shape;
 }
NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
this->dtype = quantizer.attr("dtype").cast<DType>();
this->with_rht = quantizer.attr("with_rht").cast<bool>();
this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast<bool>();
this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast<bool>();
this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast<bool>();
// Get amax reduction group if needed for NVFP4 AG
const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast<bool>();
c10::intrusive_ptr<dist_group_type> amax_reduction_group;
if (with_amax_reduction) {
auto group = quantizer.attr("_canonicalized_amax_reduction_group")();
NVTE_CHECK(!group.is_none(), "NVFP4Quantizer could not canonicalize amax reduction group");
amax_reduction_group = group.cast<c10::intrusive_ptr<dist_group_type>>();
}
this->with_amax_reduction = with_amax_reduction;
this->amax_reduction_group = amax_reduction_group;
this->rht_matrix_random_sign_mask_t = quantizer.attr("rht_matrix_random_sign_mask_t").cast<int>();
this->rht_matrix = quantizer.attr("rht_matrix").cast<at::Tensor>();
}
void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const {
// set dtype for rowwise and columnwise data in tensor wrapper
auto rowwise_data = tensor->get_rowwise_data();
rowwise_data.dtype = static_cast<NVTEDType>(this->dtype);
auto columnwise_data = tensor->get_columnwise_data();
columnwise_data.dtype = static_cast<NVTEDType>(this->dtype);
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
rowwise_data.shape);
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::vector<size_t>& shape,
DType dtype) const {
using namespace pybind11::literals;
// Tensor dimensions
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
size_t flat_first_dim = 1;
if (shape.size() > 0) {
for (size_t i = 0; i < shape.size() - 1; ++i) {
flat_first_dim *= shape[i];
}
}
const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ",
NVFP4_BLOCK_SIZE, " (got shape=", shape, ")");
  NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ",
             NVFP4_BLOCK_SIZE, " (got shape=", shape, ")");
const auto rowwise_scale_inv_shape = get_scale_shape(shape, false);
const auto columnwise_scale_inv_shape = get_scale_shape(shape, true);
// Allocate tensors
at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor, amax_rowwise;
at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise;
const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
if (rowwise_usage) {
const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
rowwise_scale_inv_shape.end());
rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts);
rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts);
amax_rowwise = at::empty({1}, bit32_tensor_opts);
}
if (columnwise_usage) {
const std::vector<int64_t> scale_inv_shape_int64(columnwise_scale_inv_shape.begin(),
columnwise_scale_inv_shape.end());
    // Enforce a 2D shape: with a [S, B, H] shape where B is 1, the transposed
    // shape would be [H, S, B] and halving its last dim would give zero.
std::vector<int64_t> shape_int64_2d = {static_cast<int64_t>(flat_first_dim),
static_cast<int64_t>(flat_last_dim)};
const auto transpose_shape_int64 = make_transpose_shape<int64_t>(shape_int64_2d);
columnwise_data_tensor =
at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts);
columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts);
amax_columnwise = at::empty({1}, bit32_tensor_opts);
}
// Convert tensors to Python
auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object {
return need_cast ? py::cast(tensor) : py::none();
};
auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage);
auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage);
auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage);
auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage);
auto amax_rowwise_py = py_cast(amax_rowwise, rowwise_usage);
auto amax_columnwise_py = py_cast(amax_columnwise, columnwise_usage);
// Construct Python NVFP4 tensor
py::object out_py;
if (internal) {
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorBasePythonClass));
out_py = NVFP4TensorClass(
"rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
"amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer);
} else {
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass));
out_py = NVFP4TensorClass(
"shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
"rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
"amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer);
}
// Construct C++ tensor
TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING);
if (rowwise_usage) {
out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape);
out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3,
rowwise_scale_inv_shape);
out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
if (columnwise_usage) {
    // Enforce a 2D shape: with a [S, B, H] shape where B is 1, the transposed
    // shape would be [H, S, B] and halving its last dim would give zero.
std::vector<size_t> shape_2d = {flat_first_dim, flat_last_dim};
auto col_data_shape_fp4 = make_transpose_shape<size_t>(shape_2d);
out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), DType::kFloat4E2M1,
col_data_shape_fp4);
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3,
columnwise_scale_inv_shape);
out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(out_py)};
}
std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_unquantized_tensor_with_amax(
TensorWrapper& quantized_tensor, DType dtype) {
// Construct tensor
auto shape = convertShape(quantized_tensor.shape());
auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype);
// Register amax pointer from quantized tensor
void* amax_ptr = quantized_tensor.amax();
if (amax_ptr == nullptr) {
amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr;
}
NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor.");
out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
// Zero out amax
NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream()));
return {std::move(out_cpp), std::move(out_py)};
}
std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
py::object tensor) const {
NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor.");
// Extract buffers from Python tensor
auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
auto attr_py = tensor.attr(name);
if (attr_py.is_none()) {
return std::nullopt;
}
return attr_py.cast<at::Tensor>();
};
auto rowwise_data = get_tensor("_rowwise_data");
auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv");
auto columnwise_data = get_tensor("_columnwise_data");
auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv");
auto amax_rowwise = get_tensor("_amax_rowwise");
auto amax_columnwise = get_tensor("_amax_columnwise");
NVTE_CHECK(rowwise_data || columnwise_data, "NVFP4Tensor has no data.");
// Tensor dimensions, shape means original shape
std::vector<size_t> shape;
if (columnwise_data) {
shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true);
if (rowwise_data) {
auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape,
") and column-wise data (shape=", shape, ") do not match");
}
  } else {  // rowwise_data must be present, since we already checked rowwise_data || columnwise_data
shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
}
size_t flat_first_dim = 1;
if (shape.size() > 0) {
for (size_t i = 0; i < shape.size() - 1; ++i) {
flat_first_dim *= shape[i];
}
}
const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
// Coerce row-wise data
if (rowwise_usage) {
if (!rowwise_data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
rowwise_data = at::empty(convert_shape_for_fp4(shape_int64), opts);
tensor.attr("_rowwise_data") = *rowwise_data;
}
if (!rowwise_scale_inv) {
const auto scale_inv_shape = get_scale_shape(shape, false);
const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
scale_inv_shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
}
if (!amax_rowwise) {
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
amax_rowwise = at::empty({1}, opts);
tensor.attr("_amax_rowwise") = *amax_rowwise;
}
} else { // rowwise_usage == false
if (rowwise_data) {
rowwise_data.reset();
tensor.attr("_rowwise_data") = py::none();
}
if (rowwise_scale_inv) {
rowwise_scale_inv.reset();
tensor.attr("_rowwise_scale_inv") = py::none();
}
if (amax_rowwise) {
amax_rowwise.reset();
tensor.attr("_amax_rowwise") = py::none();
}
}
// Coerce column-wise data
if (columnwise_usage) {
if (!columnwise_data) {
      // Enforce a 2D shape: with a [S, B, H] shape where B is 1, the transposed
      // shape would be [H, S, B] and halving its last dim would give zero.
std::vector<int64_t> shape_int64_2d = {static_cast<int64_t>(flat_first_dim),
static_cast<int64_t>(flat_last_dim)};
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
const auto transpose_shape_int64 = make_transpose_shape<int64_t>(shape_int64_2d);
columnwise_data = at::empty(convert_shape_for_fp4(transpose_shape_int64), opts);
tensor.attr("_columnwise_data") = *columnwise_data;
}
if (!columnwise_scale_inv) {
const auto scale_inv_shape = get_scale_shape(shape, true);
const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
scale_inv_shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
}
if (!amax_columnwise) {
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
amax_columnwise = at::zeros({1}, opts);
tensor.attr("_amax_columnwise") = *amax_columnwise;
}
} else { // columnwise_usage == false
if (columnwise_data) {
columnwise_data.reset();
tensor.attr("_columnwise_data") = py::none();
}
if (columnwise_scale_inv) {
columnwise_scale_inv.reset();
tensor.attr("_columnwise_scale_inv") = py::none();
}
if (amax_columnwise) {
amax_columnwise.reset();
tensor.attr("_amax_columnwise") = py::none();
}
}
// Construct C++ tensor
TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING);
if (rowwise_usage) {
out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape);
out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3,
getTensorShape(*rowwise_scale_inv));
out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
if (columnwise_usage) {
    // Enforce a 2D shape: with a [S, B, H] shape where B is 1, the transposed
    // shape would be [H, S, B] and halving its last dim would give zero.
std::vector<size_t> shape_2d = {flat_first_dim, flat_last_dim};
auto col_data_shape_fp4 = make_transpose_shape<size_t>(shape_2d);
out_cpp.set_columnwise_data(columnwise_data->data_ptr(), DType::kFloat4E2M1,
col_data_shape_fp4);
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3,
getTensorShape(*columnwise_scale_inv));
out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)};
}
void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag,
bool compute_amax) {
// Nothing to be done if input is empty
if (input.numel() == 0) {
return;
}
auto stream = at::cuda::getCurrentCUDAStream();
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization);
quant_config.set_stochastic_rounding(this->stochastic_rounding);
// We only need RHT for columnwise usage.
  // Flatten leading dims into rows and keep the last dim as columns for multi-dimensional input.
size_t rows = 1;
for (size_t i = 0; i < input.ndim() - 1; ++i) {
rows *= input.size(i);
}
size_t cols = input.size(input.ndim() - 1);
TensorWrapper te_rng_state;
if (this->stochastic_rounding) {
const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
auto rng_state = torch::empty({2}, opts);
philox_unpack(philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
te_rng_state = makeTransformerEngineTensor(rng_state);
quant_config.set_rng_state(te_rng_state.data());
}
// Restriction for the RHT cast fusion kernel.
bool eligible_for_rht_cast_fusion =
input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
// Compute amax.
if (this->with_rht) {
    NVTE_CHECK(input.dtype() == DType::kBFloat16, "RHT is only supported for bfloat16 input");
if (this->with_post_rht_amax) {
// We need:
// 1. Rowwise amax = amax for input
// 2. Columnwise amax = amax for RHT(input.t)
NVTE_SCOPED_GIL_RELEASE({
nvte_hadamard_transform_amax(input.data(), out.data(), 0,
this->rht_matrix_random_sign_mask_t, stream);
});
} else {
// raise error since it's not supported yet
NVTE_CHECK(false, "Pre-RHT amax is not supported yet");
}
} else { // Without RHT
if (compute_amax) {
// Amax pointers
auto rowwise_amax_ptr = out.get_amax().data_ptr;
auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr;
void* amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr;
NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer");
// Compute amax of input tensor
out.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
NVTE_SCOPED_GIL_RELEASE(
{ nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); });
out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector<size_t>{1});
// Make sure row-wise and column-wise amaxes match
if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float),
cudaMemcpyDeviceToDevice, stream));
}
if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float),
cudaMemcpyDeviceToDevice, stream));
}
}
}
// amax reduction
if (this->with_amax_reduction) {
std::vector<at::Tensor> amax_tensors;
// push amax tensors inside if they need to be reduced
auto make_amax_tensor = [](void* data_ptr) {
return at::from_blob(
data_ptr, std::vector<int64_t>{1},
[](void*) {}, // deleter doing nothing since it doesn't own the data
at::device(at::kCUDA).dtype(torch::kFloat32));
};
if (rowwise_usage) {
amax_tensors.push_back(make_amax_tensor(out.get_amax().data_ptr));
}
if (columnwise_usage) {
amax_tensors.push_back(make_amax_tensor(out.get_columnwise_amax().data_ptr));
}
c10d::AllreduceCoalescedOptions opts;
opts.reduceOp = c10d::ReduceOp::MAX;
NVTE_SCOPED_GIL_RELEASE(
{ this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); });
}
if (this->with_rht) {
if (rowwise_usage) {
      // For row-wise usage, quantize the input directly while leaving the column-wise output untouched.
TensorWrapper out_identity(out.scaling_mode());
auto out_identity_data = out.get_rowwise_data();
auto out_identity_scale_inv = out.get_rowwise_scale_inv();
auto out_identity_amax = out.get_amax();
out_identity.set_rowwise_data(out_identity_data.data_ptr,
static_cast<DType>(out_identity_data.dtype),
out_identity_data.shape);
out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr,
static_cast<DType>(out_identity_scale_inv.dtype),
out_identity_scale_inv.shape);
out_identity.set_amax(out_identity_amax.data_ptr, static_cast<DType>(out_identity_amax.dtype),
out_identity_amax.shape);
NVTE_SCOPED_GIL_RELEASE(
{ nvte_quantize_v2(input.data(), out_identity.data(), quant_config, stream); });
}
if (columnwise_usage) {
// Get the output columnwise data, scale_inv, and amax
auto out_columnwise_data = out.get_columnwise_data();
auto out_columnwise_scale_inv = out.get_columnwise_scale_inv();
// NOTE: should already be populated.
auto out_columnwise_amax = out.get_columnwise_amax();
      // Create a wrapper that exposes the column-wise output as a row-wise output.
      // The input `rht_output_t` is already in the transposed layout, so a row-wise
      // quantization is enough to produce the column-wise output.
TensorWrapper out_transpose(out.scaling_mode());
      // Note: since we are treating the column-wise tensor as row-wise, the flat-first-dim check
      // would fail, so convert the shape to 2D here.
      auto colwise_data_shape = out_columnwise_data.shape;
      std::vector<size_t> colwise_data_shape_2d;
      // A stored shape of [512, 32, 64] really holds [512, 32, 128] values, since two FP4 values
      // share one byte. The 2D shape should then be [512, 32*128], but the column-wise data shape
      // expects its last dim to be halved again, so the factors of 2 cancel out.
colwise_data_shape_2d.push_back(colwise_data_shape.data[0]);
size_t last_dim = 1;
for (size_t i = 1; i < colwise_data_shape.ndim; ++i) {
last_dim *= colwise_data_shape.data[i];
}
colwise_data_shape_2d.push_back(last_dim);
out_transpose.set_rowwise_data(out_columnwise_data.data_ptr,
static_cast<DType>(out_columnwise_data.dtype),
colwise_data_shape_2d);
out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr,
static_cast<DType>(out_columnwise_scale_inv.dtype),
out_columnwise_scale_inv.shape);
out_transpose.set_amax(out_columnwise_amax.data_ptr,
static_cast<DType>(out_columnwise_amax.dtype),
out_columnwise_amax.shape);
if (!eligible_for_rht_cast_fusion) {
// Invoking fallback RHT kernel.
// If using RHT, then amax will be computed in the RHT step
// If not using RHT, then amax will be computed based on input x
at::Tensor rht_output_t; // The RHT(x_t) output, in columnwise layout
// This wrapper is going to be passed as input to the quantization kernel.
TensorWrapper rht_output_t_cpp; // Wrapper to contain the RHT(x) and RHT(x_t) outputs
rht_output_t =
allocateTorchTensor(static_cast<int>(cols), static_cast<int>(rows), input.dtype());
        // NOTE (frsun): This is non-intuitive: we are writing the
        // result of the transposed RHT to the row-wise output.
rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(),
std::vector<size_t>{cols, rows});
NVTE_SCOPED_GIL_RELEASE({
// Perform the RHT(input.t), and write to rht_output_cpp.columnwise.
nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0,
this->rht_matrix_random_sign_mask_t, stream);
});
// Quantize kernel will treat everything as rowwise input/output, which is
// intended.
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config, stream);
});
} else {
// RHT cast fusion kernel.
NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0,
"RHT matrix is not set");
auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix);
NVTE_SCOPED_GIL_RELEASE({
nvte_hadamard_transform_cast_fusion_columnwise(
input.data(), out_transpose.data(), rht_matrix_nvte.data(), quant_config, stream);
});
}
}
} else {
NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); });
}
}
void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
this->quantize_impl(input, out, noop_flag, true);
}
void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) {
// Update output tensor amaxes with input tensor amax
auto input_amax_ptr = input.amax();
auto output_rowwise_amax_ptr = out.get_amax().data_ptr;
auto output_columnwise_amax_ptr = out.get_columnwise_amax().data_ptr;
NVTE_CHECK(input_amax_ptr != nullptr ||
(output_rowwise_amax_ptr == nullptr && output_columnwise_amax_ptr == nullptr),
"Input tensor does not have pre-computed amax");
if (input_amax_ptr != output_rowwise_amax_ptr && input_amax_ptr != nullptr &&
output_rowwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(output_rowwise_amax_ptr, input_amax_ptr, sizeof(float),
cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream()));
}
if (input_amax_ptr != output_columnwise_amax_ptr && input_amax_ptr != nullptr &&
output_columnwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(output_columnwise_amax_ptr, input_amax_ptr, sizeof(float),
cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream()));
}
input.set_amax(nullptr, DType::kFloat32, input.defaultShape);
// Perform quantization
this->quantize_impl(input, out, std::nullopt, false);
}
std::vector<size_t> NVFP4Quantizer::get_scale_shape(const std::vector<size_t>& shape,
bool columnwise) const {
size_t numel = 1;
for (auto s : shape) {
numel *= s;
}
auto last_dim = shape.back();
auto flat_first_dim = numel / last_dim;
NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ",
NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")");
NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0,
"NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE,
" (got shape=", shape, ")");
std::vector<size_t> scale_shape;
bool rowwise_usage = !columnwise;
if (rowwise_usage) {
// rowwise scaling factor shape
size_t sinv0 = roundup(flat_first_dim, 128);
size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4);
scale_shape = {sinv0, sinv1};
} else {
// columnwise scaling factor shape
size_t sinv0 = roundup(last_dim, 128);
size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4);
scale_shape = {sinv0, sinv1};
}
return scale_shape;
}
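For reference, the scale-inverse shapes computed above pad the outer dimension to a multiple of 128 and the per-block inner dimension to a multiple of 4, with one FP8 scale per 16-element block. A small Python sketch of the same arithmetic, assuming NVFP4_BLOCK_SIZE == 16:

def nvfp4_scale_shape(shape, columnwise, block=16):
    # Mirrors NVFP4Quantizer::get_scale_shape: outer dim rounded up to 128,
    # number of scale blocks rounded up to a multiple of 4.
    def roundup(x, m):
        return (x + m - 1) // m * m
    last_dim = shape[-1]
    flat_first_dim = 1
    for s in shape[:-1]:
        flat_first_dim *= s
    if not columnwise:
        return (roundup(flat_first_dim, 128), roundup(last_dim // block, 4))
    return (roundup(last_dim, 128), roundup(flat_first_dim // block, 4))

# e.g. for a [256, 1024] tensor:
assert nvfp4_scale_shape((256, 1024), columnwise=False) == (256, 64)
assert nvfp4_scale_shape((256, 1024), columnwise=True) == (1024, 16)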
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
...@@ -116,6 +116,46 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer ...@@ -116,6 +116,46 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
return ret; return ret;
} }
TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) {
const DType dtype = tensor.attr("_fp4_dtype").cast<DType>();
auto ret = TensorWrapper(NVTE_NVFP4_1D_SCALING);
bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor.");
// Row-scaled data
if (rowwise_usage) {
const auto &data = tensor.attr("_rowwise_data").cast<at::Tensor>();
const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast<at::Tensor>();
ret.set_rowwise_data(data.data_ptr(), dtype,
convert_shape_back_from_fp4(getTensorShape(data), false));
ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv));
ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise));
}
// Column-scaled data
if (columnwise_usage) {
const auto &data = tensor.attr("_columnwise_data").cast<at::Tensor>();
const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast<at::Tensor>();
ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1,
convert_shape_back_from_fp4(getTensorShape(data), false));
ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3,
getTensorShape(scale_inv));
ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
getTensorShape(amax_columnwise));
}
// Quantizer state
quantizer->set_quantization_params(&ret);
return ret;
}
} // namespace detail } // namespace detail
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
...@@ -14,22 +14,31 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap ...@@ -14,22 +14,31 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
if (input.scaling_mode() == NVTE_INVALID_SCALING) { if (input.scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle."); NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING &&
input.scaling_mode() != NVTE_NVFP4_1D_SCALING) {
return std::nullopt; return std::nullopt;
} }
NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); NVTE_CHECK(input.element_size_bits() == 4 || input.element_size_bits() == 8,
"4-bit or 8-bit input required for swizzling scaling factors.");
const auto nvfp4 = input.scaling_mode() == NVTE_NVFP4_1D_SCALING;
NVTEBasicTensor scale_inv; NVTEBasicTensor scale_inv;
NVTEShape nvte_input_shape;
if (rowwise) { if (rowwise) {
nvte_input_shape = input.shape();
scale_inv = input.get_rowwise_scale_inv(); scale_inv = input.get_rowwise_scale_inv();
} else { } else {
nvte_input_shape = input.get_columnwise_data().shape;
scale_inv = input.get_columnwise_scale_inv(); scale_inv = input.get_columnwise_scale_inv();
} }
auto input_shape = nvte_shape_to_vector(input.shape()); auto input_shape = nvte_shape_to_vector(nvte_input_shape);
auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape.");
// Allocate memory for swizzled output. // Allocate memory for swizzled output.
auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
std::vector<int64_t> scale_inv_shape_int; std::vector<int64_t> scale_inv_shape_int;
...@@ -41,36 +50,34 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap ...@@ -41,36 +50,34 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
// Reconstruct input only to avoid swizzling both directions if not needed. // Reconstruct input only to avoid swizzling both directions if not needed.
// Use any 8 bit type, it's irrelevant. // The specific dtype used is irrelevant, just needs to be correct bits.
transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); transformer_engine::TensorWrapper input_cu(input.scaling_mode());
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); transformer_engine::TensorWrapper output_cu(input.scaling_mode());
const auto input_dtype =
(nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3;
const auto scale_inv_dtype =
(nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0;
if (rowwise) { if (rowwise) {
input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
scale_inv_shape); output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shape);
} else { } else {
input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
input_shape); input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
scale_inv_shape); output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
output_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
input_shape);
output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
transformer_engine::DType::kFloat8E8M0, scale_inv_shape);
} }
// Launch kernel // Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
if (rowwise) { if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
scale_inv_shape);
} else { } else {
input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
scale_inv_shape);
} }
return swizzled_scale_inv; return swizzled_scale_inv;
......
...@@ -39,11 +39,14 @@ from .constants import dist_group_type ...@@ -39,11 +39,14 @@ from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.nvfp4_tensor import NVFP4Quantizer
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
from .tensor.quantized_tensor import QuantizedTensor, Quantizer from .tensor.quantized_tensor import QuantizedTensorBase, QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from .tensor._internal.nvfp4_tensor_base import NVFP4TensorBase
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .triton.pad import pad_columnwise_scale_inv
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
...@@ -1204,6 +1207,245 @@ def _all_gather_fp8_blockwise( ...@@ -1204,6 +1207,245 @@ def _all_gather_fp8_blockwise(
return out, handle return out, handle
def _swap_first_dims(tensor: torch.Tensor, world_size: int):
"""
    Swap the first two dimensions of a tensor to fix the interleaved
    data layout that results from gathering transposed data.
    For more than two dimensions, we flatten the trailing dimensions
    rather than the leading ones, because the shape passed to this
    function is already transposed.
"""
shape = tensor.shape
assert tensor.ndim >= 2, "Wrong number of dimensions for fixing interleave."
first_dim = shape[0]
flattened_trailing = math.prod(shape[1:])
assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave."
tensor = tensor.reshape(world_size, first_dim // world_size, flattened_trailing)
tensor = tex.swap_first_dims(tensor, out=None)
return tensor.reshape(first_dim // world_size, flattened_trailing * world_size)
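To see why the swap is needed: each rank holds the transpose of its local shard, of shape [H, S_local], and the all-gather stacks those transposes along dim 0, whereas the global transpose needs every rank's columns laid out side by side along dim 1. A minimal sketch with plain torch ops (using permute in place of tex.swap_first_dims, which is assumed here to perform the equivalent reordering):

import math
import torch

def swap_first_dims_ref(tensor, world_size):
    # Reference version of _swap_first_dims using plain PyTorch ops.
    first_dim = tensor.shape[0]
    trailing = math.prod(tensor.shape[1:])
    t = tensor.reshape(world_size, first_dim // world_size, trailing)
    t = t.permute(1, 0, 2).contiguous()
    return t.reshape(first_dim // world_size, trailing * world_size)

# Two ranks each hold the transpose of a [2, 3] shard. Gathering the transposes
# along dim 0 interleaves them; the swap recovers the transpose of the full tensor.
x0, x1 = torch.arange(6).reshape(2, 3), torch.arange(6, 12).reshape(2, 3)
gathered_t = torch.cat([x0.T, x1.T], dim=0)           # [6, 2], interleaved
expected = torch.cat([x0, x1], dim=0).T.contiguous()  # [3, 4], true global transpose
assert torch.equal(swap_first_dims_ref(gathered_t, 2), expected)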
def _post_process_nvfp4_gather(
out: NVFP4TensorBase,
columnwise_data_interleaved: torch.Tensor,
columnwise_scale_inv_interleaved: torch.Tensor,
world_size: int,
handle: Optional[torch.distributed.Work] = None,
) -> NVFP4TensorBase:
"""Post-process FP8 blockwise gather."""
if handle is not None:
handle.wait()
handle = None
# Fix the interleaved transposed data from gathering along first dim.
out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
    # Pad the column-wise scale-inverse if needed.
out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
@dataclass
class _NVFP4AllGatherAsyncHandle:
"""Handle for asynchronous NVFP4 all-gather."""
output: NVFP4TensorBase
columnwise_data_interleaved: torch.Tensor
columnwise_scale_inv_interleaved: torch.Tensor
world_size: int
async_handle: torch.distributed.Work
_synchronized: bool = False
def wait(self) -> None:
"""Wait for the async operation to complete and post-process the tensor."""
if self._synchronized:
return
self.async_handle.wait()
_post_process_nvfp4_gather(
self.output,
self.columnwise_data_interleaved,
self.columnwise_scale_inv_interleaved,
self.world_size,
)
self._synchronized = True
def _all_gather_nvfp4(
inp: torch.Tensor,
process_group: dist_group_type,
*,
async_op: bool = False,
quantizer: NVFP4Quantizer,
out_shape: Optional[list[int]] = None,
) -> tuple[NVFP4TensorBase, Optional[torch.distributed.Work]]:
"""All-gather NVFP4 tensor along first dimension."""
# Input tensor attributes
    in_shape: Optional[Iterable[int]] = None
    in_shape_t: Optional[Iterable[int]] = None
device: torch.device
dtype: torch.dtype
# Construct packed shapes for input and input_t.
if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorBase):
# High-precision tensor.
in_shape = NVFP4Quantizer.convert_shape_for_fp4(inp.size())
in_shape_t = NVFP4Quantizer.convert_shape_for_fp4(
NVFP4Quantizer.get_columnwise_shape(inp.size())
)
device = inp.device
dtype = inp.dtype
elif isinstance(inp, NVFP4TensorBase):
if inp._rowwise_data is not None:
in_shape = inp._rowwise_data.size()
device = inp._rowwise_data.device
if inp._columnwise_data is not None:
in_shape_t = inp._columnwise_data.size()
device = inp._columnwise_data.device
dtype = torch.bfloat16
else:
raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or NVFP4TensorBase, "
f"found {inp.__class__.__name__})"
)
assert in_shape is not None or in_shape_t is not None, "No data found."
world_size = get_distributed_world_size(process_group)
if out_shape is None:
out_shape = [in_shape[0] * world_size] + in_shape[1:]
# For cases where inp has dimensions that cannot be quantized,
# we gather in high precision followed by a cast to NVFP4.
if (
not isinstance(inp, NVFP4TensorBase)
and quantizer is not None
and not quantizer.is_quantizable(inp)
):
out = torch.empty(
out_shape,
dtype=dtype,
device=device,
memory_format=torch.contiguous_format,
)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
out = quantizer(out)
return out, None
# Cast input tensor to NVFP4 with required data
if not isinstance(inp, NVFP4TensorBase):
inp = quantizer(inp)
elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
quantizer.columnwise_usage and inp._columnwise_data is None
):
warnings.warn(
"Input and quantizer do not have matching usages. "
"Dequantizing and requantizing to NVFP4."
)
inp = quantizer(inp.dequantize())
# Construct NVFP4 output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
# Coalesce NCCL collectives for gathering data and scale inverses.
with torch.distributed._coalescing_manager(
group=process_group,
device=device,
async_ops=async_op,
) as gather_coalescing_manager:
# Gather NVFP4 data for row-wise usage
if quantizer.rowwise_usage:
# Remove padding from NVFP4 scale-inverses
assert in_shape is not None, "Shape not found."
in_scale_inv = inp._rowwise_scale_inv
out_scale_inv = out._rowwise_scale_inv
flattened_in_shape0 = math.prod(in_shape[:-1])
if in_scale_inv.size(0) != flattened_in_shape0:
in_scale_inv = in_scale_inv[:flattened_in_shape0]
out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
# Launch all-gathers
torch.distributed.all_gather_into_tensor(
out_scale_inv,
in_scale_inv,
group=process_group,
)
torch.distributed.all_gather_into_tensor(
out._rowwise_data,
inp._rowwise_data,
group=process_group,
)
# Transfer amax to output.
out._amax_rowwise = inp._amax_rowwise
# Gather the transposed NVFP4 data along first dimension. Fix format later.
if quantizer.columnwise_usage:
# Remove padding from NVFP4 scale-inverses
            # For an all-gather on transposed scale-inverses,
            # we need to remove padding from both dimensions.
in_scale_inv = inp._columnwise_scale_inv
            # Note: for in_shape_t, it is the trailing dimensions that get flattened.
flattened_in_shape0 = in_shape_t[0]
flattened_in_shape1 = math.prod(in_shape_t[1:])
# Remove dim0 padding
if in_scale_inv.size(0) != flattened_in_shape0:
in_scale_inv = in_scale_inv[:flattened_in_shape0]
            # Remove dim1 padding: two FP4 values are packed per byte and there is
            # one scale per 16-element block.
unpadded_dim1 = flattened_in_shape1 * 2 // 16
if in_scale_inv.size(1) != unpadded_dim1:
in_scale_inv = in_scale_inv[:, :unpadded_dim1].contiguous()
# Construct tensor to gather transposed scale_inv (interleaved) and launch AG.
out_scale_inv = torch.empty(
[flattened_in_shape0 * world_size] + [in_scale_inv.shape[1]],
dtype=in_scale_inv.dtype,
layout=in_scale_inv.layout,
device=in_scale_inv.device,
)
torch.distributed.all_gather_into_tensor(
out_scale_inv,
in_scale_inv,
group=process_group,
)
# Construct tensor to gather transposed data (interleaved) and launch AG.
out_columnwise_data = torch.empty(
[inp._columnwise_data.shape[0] * world_size] + list(inp._columnwise_data.shape[1:]),
dtype=inp._columnwise_data.dtype,
layout=inp._columnwise_data.layout,
device=inp._columnwise_data.device,
)
torch.distributed.all_gather_into_tensor(
out_columnwise_data,
inp._columnwise_data,
group=process_group,
)
# Transfer amax to output.
out._amax_columnwise = inp._amax_columnwise
handle = gather_coalescing_manager if async_op else None
# Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed.
if async_op and quantizer.columnwise_usage:
handle = _NVFP4AllGatherAsyncHandle(
out, out_columnwise_data, out_scale_inv, world_size, handle
)
elif quantizer.columnwise_usage:
_post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle)
return out, handle
def _all_gather_mxfp8( def _all_gather_mxfp8(
inp: torch.Tensor, inp: torch.Tensor,
process_group: dist_group_type, process_group: dist_group_type,
...@@ -1291,7 +1533,6 @@ def _all_gather_mxfp8( ...@@ -1291,7 +1533,6 @@ def _all_gather_mxfp8(
flattened_in_shape0 = math.prod(in_shape[:-1]) flattened_in_shape0 = math.prod(in_shape[:-1])
if in_scale_inv.size(0) != flattened_in_shape0: if in_scale_inv.size(0) != flattened_in_shape0:
in_scale_inv = in_scale_inv[:flattened_in_shape0] in_scale_inv = in_scale_inv[:flattened_in_shape0]
out_scale_inv[flattened_in_shape0 * world_size :].zero_()
out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
# Launch all-gathers # Launch all-gathers
...@@ -1315,7 +1556,6 @@ def _all_gather_mxfp8( ...@@ -1315,7 +1556,6 @@ def _all_gather_mxfp8(
flattened_in_shape0 = math.prod(in_shape[:-1]) // 32 flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
if in_scale_inv.size(0) != flattened_in_shape0: if in_scale_inv.size(0) != flattened_in_shape0:
in_scale_inv = in_scale_inv[:flattened_in_shape0] in_scale_inv = in_scale_inv[:flattened_in_shape0]
out_scale_inv[flattened_in_shape0 * world_size :].zero_()
out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
# Launch all-gathers # Launch all-gathers
...@@ -1347,7 +1587,7 @@ def gather_along_first_dim( ...@@ -1347,7 +1587,7 @@ def gather_along_first_dim(
# Return immediately if no communication is required # Return immediately if no communication is required
world_size = get_distributed_world_size(process_group) world_size = get_distributed_world_size(process_group)
if world_size == 1: if world_size == 1:
if quantizer is not None and not isinstance(inp, QuantizedTensor): if quantizer is not None and not isinstance(inp, QuantizedTensorBase):
inp = quantizer(inp) inp = quantizer(inp)
return inp, None return inp, None
...@@ -1426,13 +1666,24 @@ def gather_along_first_dim( ...@@ -1426,13 +1666,24 @@ def gather_along_first_dim(
out_shape=out_shape, out_shape=out_shape,
) )
# NVFP4 case
if isinstance(inp, NVFP4TensorBase) or isinstance(quantizer, NVFP4Quantizer):
assert isinstance(quantizer, NVFP4Quantizer)
return _all_gather_nvfp4(
inp,
process_group,
async_op=async_op,
quantizer=quantizer,
out_shape=out_shape,
)
# High-precision communication for quantized tensors # High-precision communication for quantized tensors
if quantizer is not None: if quantizer is not None:
warnings.warn( warnings.warn(
"Attempting to all-gather an unsupported quantized tensor. " "Attempting to all-gather an unsupported quantized tensor. "
"Falling back to high-precision all-gather." "Falling back to high-precision all-gather."
) )
if isinstance(inp, QuantizedTensor): if isinstance(inp, QuantizedTensorBase):
inp = inp.dequantize() inp = inp.dequantize()
# Falling back to high-precision all-gather for Float8BlockQuantizer # Falling back to high-precision all-gather for Float8BlockQuantizer
# means that it should directly output GEMM_READY format # means that it should directly output GEMM_READY format
...@@ -1450,7 +1701,7 @@ def gather_along_first_dim( ...@@ -1450,7 +1701,7 @@ def gather_along_first_dim(
return out, None return out, None
# Dequantize quantized tensor if not supported # Dequantize quantized tensor if not supported
if isinstance(inp, QuantizedTensor): if isinstance(inp, QuantizedTensorBase):
warnings.warn( warnings.warn(
"Attempting to all-gather an unsupported quantized tensor. " "Attempting to all-gather an unsupported quantized tensor. "
"Falling back to high-precision all-gather." "Falling back to high-precision all-gather."
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Experimental features and APIs."""
from .config import set_qlinear_params, get_experimental_quantizers
__all__ = ["set_qlinear_params", "get_experimental_quantizers"]
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Config API for experimental middleware between Transformer Engine and Kitchen."""
import dataclasses
import enum
import os
from typing import Optional
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.experimental import quantization
from transformer_engine.pytorch.experimental import quantization_microblock_ref
from transformer_engine.pytorch.experimental.quantization import MMParams
@dataclasses.dataclass()
class QLinearParams:
"""Quantization parameters of linear layer.
Contains ready-to-use quantizers for input (x), weight (w), and gradient (g) tensors.
"""
x_quantizer: Optional[quantization.ExperimentalQuantizer] = None
w_quantizer: Optional[quantization.ExperimentalQuantizer] = None
g_quantizer: Optional[quantization.ExperimentalQuantizer] = None
mm_fprop: Optional[MMParams] = None
mm_dgrad: Optional[MMParams] = None
mm_wgrad: Optional[MMParams] = None
@enum.unique
class QuantizeRecipe(enum.Enum):
"""Pre-defined quantization recipes for linear layers."""
NON_QUANTIZE = "non_quantize"
NVFP4_REF = "nvfp4_ref"
NVFP4_REF_RHT_ONLY = "nvfp4_ref_rht_only"
NVFP4_REF_2D_QUANTIZATION_ONLY = "nvfp4_ref_2d_quantization_only"
NVFP4_REF_RHT_AND_2D_QUANTIZATION = "nvfp4_ref_rht_and_2d_quantization"
def get_qlinear_params_from_predefined(
recipe: QuantizeRecipe,
) -> Optional[QLinearParams]:
"""Get quantization parameters for linear layer based on recipe."""
if recipe == QuantizeRecipe.NON_QUANTIZE:
return None
if recipe == QuantizeRecipe.NVFP4_REF:
return QLinearParams(
x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
),
w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
),
g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
),
)
if recipe == QuantizeRecipe.NVFP4_REF_RHT_ONLY:
return QLinearParams(
x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
),
w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=False,
),
g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
),
)
if recipe == QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY:
return QLinearParams(
x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=False,
),
w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16),
pow_2_scales=False,
with_rht=False,
),
g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=False,
),
)
if recipe == QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION:
return QLinearParams(
x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
),
w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16),
pow_2_scales=False,
with_rht=False,
),
g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
),
)
raise ValueError(f"Unsupported quantize recipe: {recipe}")
def get_qlinear_params_from_qat_params(qat_params_idx: int) -> Optional[QLinearParams]:
"""Load quantization options from Kitchen to Transformer Engine.
TODO(etsykunov): Confirm docstring is correct.
"""
assert qat_params_idx > 0, "QAT_PARAMS is not set."
if qat_params_idx == 6010:
return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF)
if qat_params_idx == 960109:
return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_ONLY)
if qat_params_idx == 9002:
return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY)
if qat_params_idx == 9003:
return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION)
raise ValueError(f"Unsupported QAT params index: {qat_params_idx}")
def set_qlinear_params(
qlinear_params: Optional[QLinearParams] = None,
layer_number: Optional[int] = None,
layer_name: Optional[str] = None,
) -> Optional[QLinearParams]:
"""Set quantization parameters based on configuration.
Args:
qlinear_params: Quantization parameters. If None, loaded from environment.
layer_number: The numerical index of this layer in the model structure.
layer_name: The name for this layer.
Returns:
QLinearParams: The finalized quantization parameters for this layer.
"""
if qlinear_params is None:
qat_params_idx = int(os.getenv("QAT_PARAMS", "0"))
if qat_params_idx == 0:
return None
return get_qlinear_params_from_qat_params(qat_params_idx)
# Apply layer-specific overrides
if layer_number is not None:
raise NotImplementedError("Layer-specific overrides are not supported yet.")
if layer_name is not None:
raise NotImplementedError("Layer-specific overrides are not supported yet.")
return qlinear_params
def get_experimental_quantizers(fp8: bool, qlinear_params: QLinearParams):
"""Replacement of _get_quantizers() in TE modules."""
if not fp8:
raise ValueError("FP8 is required to be enabled for experimental quantization.")
input_quantizer = qlinear_params.x_quantizer
weight_quantizer = qlinear_params.w_quantizer
output_quantizer = None
grad_input_quantizer = None
grad_weight_quantizer = None
grad_output_quantizer = qlinear_params.g_quantizer
return (
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
)
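A short usage sketch for the config API above (the import path is assumed from the file layout in this PR; QAT_PARAMS value 6010 maps to the plain NVFP4 reference recipe, as defined in get_qlinear_params_from_qat_params):

import os

from transformer_engine.pytorch.experimental.config import (
    QuantizeRecipe,
    get_experimental_quantizers,
    get_qlinear_params_from_predefined,
    set_qlinear_params,
)

# Select a recipe explicitly...
params = get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF)
assert params.x_quantizer is not None

# ...or let the QAT_PARAMS environment variable drive the choice ("0" disables it).
os.environ["QAT_PARAMS"] = "6010"  # NVFP4_REF
assert set_qlinear_params() is not None

# Unpack the per-tensor quantizers the way TE modules would.
quantizers = get_experimental_quantizers(fp8=True, qlinear_params=params)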
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""GEMM API for experimental middleware between Transformer Engine and Kitchen."""
from typing import Iterable, Optional
import torch
from transformer_engine.pytorch.experimental.quantization import (
MMParams,
GEMMType,
ExperimentalQuantizedTensor,
)
from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer
def experimental_gemm(
A: ExperimentalQuantizedTensor,
B: ExperimentalQuantizedTensor,
workspace: torch.Tensor, # pylint: disable=unused-argument
out_dtype: Optional[torch.dtype] = None,
quantization_params: Optional[Quantizer] = None, # pylint: disable=unused-argument
gelu: bool = False, # pylint: disable=unused-argument
gelu_in: torch.Tensor = None, # pylint: disable=unused-argument
accumulate: bool = False, # pylint: disable=unused-argument
layout: str = "TN",
out: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
bias: Optional[torch.Tensor] = None,
use_split_accumulator: bool = False,
grad: bool = False,
) -> Iterable[Optional[torch.Tensor]]:
"""Dispatch GEMM to quantizer's qgemm method."""
assert isinstance(A, ExperimentalQuantizedTensor) and isinstance(
B, ExperimentalQuantizedTensor
), "A and B must be ExperimentalQuantizedTensor instances"
    # Swap operands so that A holds the activation/gradient tensor and B holds the weight,
    # matching how the qgemm calls below consume them (callers follow the general_gemm order).
    A, B = B, A
# Determine GEMM type based on grad flag and layout
if not grad:
gemm_type = GEMMType.FPROP
else:
if layout == "NN":
gemm_type = GEMMType.DGRAD
elif layout == "NT":
gemm_type = GEMMType.WGRAD
else:
# Default to FPROP for other layouts
gemm_type = GEMMType.FPROP
# Extract quantizer from QuantizedTensor to get qgemm logic
    # TODO(etsykunov): Make this more flexible; we might want to use the GEMM logic from B.quantizer.
quantizer = None
if hasattr(A, "quantizer") and A.quantizer is not None:
quantizer = A.quantizer
elif hasattr(B, "quantizer") and B.quantizer is not None:
quantizer = B.quantizer
else:
raise ValueError("No quantizer found in QuantizedETensor objects")
# Create MMParams
m_params = MMParams(
out_dtype=out_dtype,
use_split_accumulator=use_split_accumulator,
)
out_dtype = A.dtype if m_params.out_dtype is None else m_params.out_dtype
if gemm_type == GEMMType.FPROP:
qx, sx = A.data, A.scale
qw, sw = B.data, B.scale
assert qx is not None
assert sx is not None
assert qw is not None
assert sw is not None
assert A.original_shape is not None
# Call quantizer's qgemm method
result = quantizer.qgemm(
qx,
qw,
m_params,
out_dtype,
sx,
sw,
bias,
gemm_type=GEMMType.FPROP,
qresult_x=A,
qresult_w=B,
)
if len(A.original_shape) > 2:
# Original input was 3D, so we need to reshape result back to 3D
batch_size = A.original_shape[0]
seq_len = A.original_shape[1]
result = result.view(batch_size, seq_len, result.shape[-1])
elif gemm_type == GEMMType.DGRAD:
qdy, sdy = A.data, A.scale
qw_t, sw_t = B.data_t, B.scale_t
assert qdy is not None
assert sdy is not None
assert qw_t is not None
assert sw_t is not None
result = quantizer.qgemm(
qdy,
qw_t,
m_params,
out_dtype,
sdy,
sw_t,
None,
gemm_type=GEMMType.DGRAD,
qresult_x=A,
qresult_w=B,
)
elif gemm_type == GEMMType.WGRAD:
qdy_t, sdy_t = A.data_t, A.scale_t
qx_t, sx_t = B.data_t, B.scale_t
assert qdy_t is not None
assert sdy_t is not None
assert qx_t is not None
assert sx_t is not None
result = quantizer.qgemm(
qdy_t,
qx_t,
m_params,
out_dtype,
sdy_t,
sx_t,
None,
gemm_type=GEMMType.WGRAD,
qresult_x=A,
qresult_w=B,
)
# Return in the same format as general_gemm
return result, None, None, None
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Quantization API for experimental middleware between Transformer Engine and Kitchen."""
from __future__ import annotations
import abc
import dataclasses
import enum
from typing import Iterable, Optional, Tuple, Union
import torch
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from transformer_engine.pytorch.experimental import utils
@enum.unique
class GEMMType(enum.Enum):
"""Type of GEMM operation being performed."""
FPROP = "fprop"
DGRAD = "dgrad"
WGRAD = "wgrad"
@dataclasses.dataclass(frozen=True)
class MMParams:
"""Matrix multiplication parameters."""
out_dtype: torch.dtype | None = None
# Use split accumulator for more accurate FP8 GEMM
use_split_accumulator: bool = True
@dataclasses.dataclass
class ExperimentalQuantizedTensor(QuantizedTensorBase):
"""Base class for experimental quantized tensor containers.
An experimental container to hold quantization result, including quantized tensor, optional
transposed quantized tensor, and corresponding decoding scales.
data: torch.Tensor
the quantized tensor.
scale: torch.Tensor
the decoding scale for the quantized tensor. Shape depends on the scaling granularity.
- if scaling type is PER_TENSOR, it should be a 1D scalar tensor.
data_t: torch.Tensor
the transposed quantized tensor (computed lazily if needed).
scale_t: torch.Tensor
the decoding scale for the transposed quantized tensor.
dtype: torch.dtype
nominal tensor datatype.
device: torch.device
device of the tensor.
quant_dtype: Union[utils.Fp4Formats, torch.dtype]
low precision tensor datatype.
original_shape: Tuple[int, ...]
original shape of the tensor.
quantizer: ExperimentalQuantizer
Builder class for quantized tensor.
"""
data: Optional[torch.Tensor] = None
scale: Optional[torch.Tensor] = None
data_t: Optional[torch.Tensor] = None
scale_t: Optional[torch.Tensor] = None
global_amax_row: Optional[torch.Tensor] = None
global_amax_col: Optional[torch.Tensor] = None
dtype: Optional[torch.dtype] = None
device: Optional[torch.device] = None
quant_dtype: Optional[Union[utils.Fp4Formats, torch.dtype]] = None
original_shape: Optional[Tuple[int, ...]] = None
quantizer: Optional[ExperimentalQuantizer] = None
@property
def experimental(self) -> bool:
"""Flag to indicate this quantizer is using experimental Kitchen middleware."""
return True
def get_quantizer(self) -> ExperimentalQuantizer:
"""Get builder for QuantizedExperimentalTensor
Quantizer can be used for in-place operations.
"""
if self.quantizer is not None:
return self.quantizer
raise ValueError("Quantizer is not set")
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], ExperimentalQuantizedTensor]:
"""Prepare the quantization result for saving for backward"""
tensors = [self.data, self.data_t, self.scale, self.scale_t]
self.data = None
self.data_t = None
self.scale = None
self.scale_t = None
return tensors, self
def restore_from_saved(
self, tensors: list[Optional[torch.Tensor]]
) -> list[Optional[torch.Tensor]]:
"""Restore the quantization result from the saved tensors"""
self.data = tensors[0]
self.data_t = tensors[1]
self.scale = tensors[2]
self.scale_t = tensors[3]
return tensors[4:]
def dequantize(self, *args, **kwargs) -> torch.Tensor:
"""Dequantize the quantized tensor"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement dequantize function"
)
# Compatibility
@property
def _data(self):
return self.data
@_data.setter
def _data(self, value):
self.data = value
@property
def _scale_inv(self):
return self.scale
@_scale_inv.setter
def _scale_inv(self, value):
self.scale = value
class ExperimentalQuantizer(Quantizer):
"""Experimental Quantizer class
Defines the interface for experimental quantizers.
"""
def __init__(self, *, rowwise: bool, columnwise: bool) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.internal = True
@property
def experimental(self) -> bool:
"""Flag to indicate this quantizer is using experimental Kitchen middleware"""
return True
@abc.abstractmethod
def qgemm(
self,
qx: torch.Tensor,
qw: torch.Tensor,
m_params: MMParams,
out_dtype: torch.dtype,
sx: torch.Tensor,
sw: torch.Tensor,
bias: torch.Tensor | None = None,
out: torch.Tensor | None = None,
accumulate: bool = False,
gemm_type: GEMMType = GEMMType.FPROP,
qresult_x: ExperimentalQuantizedTensor | None = None,
qresult_w: ExperimentalQuantizedTensor | None = None,
) -> torch.Tensor:
"""Quantized GEMM interface."""
def dequantize(self, *args, **kwargs) -> torch.Tensor:
"""Dequantize the quantized tensor"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement dequantize function"
)
def update_quantized(self, *args, **kwargs) -> torch.Tensor:
"""Update the quantized tensor with the given tensor in-place"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement update_quantized function"
)
def make_empty(
self,
shape: Iterable[int],
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> QuantizedTensorBase:
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement make_empty function"
)
def calibrate(self, tensor: torch.Tensor) -> None:
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement calibrate function"
)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement _get_compatible_recipe function"
)
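# Illustrative sketch, not part of the original file: a minimal round trip
# through prepare_for_saving()/restore_from_saved(), assuming QuantizedTensorBase
# allows this dataclass to be instantiated directly. Field values are placeholders.
def _example_save_restore() -> None:
    qt = ExperimentalQuantizedTensor(
        data=torch.zeros(2, 1, dtype=torch.uint8),
        scale=torch.ones(2, 1),
    )
    saved, qt = qt.prepare_for_saving()  # buffers handed to the caller
    assert qt.data is None and qt.scale is None
    leftover = qt.restore_from_saved(saved)  # buffers restored in-place
    assert leftover == [] and qt.data is not None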
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""NVFP4 implementations for experimental middleware between Transformer Engine and Kitchen."""
from typing import Optional, Tuple
import torch
from transformer_engine.pytorch.experimental import quantization
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.experimental.quantization import (
ExperimentalQuantizedTensor,
ExperimentalQuantizer,
)
def cast_to_fp4x2(x):
"""Quantize a tensor to FP4 E2M1 and store in a byte tensor"""
result = torch.zeros_like(x, dtype=torch.uint8)
result[(x >= 0.0) & (x <= 0.25)] = 0
result[(x > 0.25) & (x < 0.75)] = 1
result[(x >= 0.75) & (x <= 1.25)] = 2
result[(x > 1.25) & (x < 1.75)] = 3
result[(x >= 1.75) & (x <= 2.5)] = 4
result[(x > 2.5) & (x < 3.5)] = 5
result[(x >= 3.5) & (x <= 5.0)] = 6
result[x > 5.0] = 7
result[(x >= -0.25) & (x < -0.0)] = 8
result[(x < -0.25) & (x > -0.75)] = 9
result[(x <= -0.75) & (x >= -1.25)] = 10
result[(x < -1.25) & (x > -1.75)] = 11
result[(x <= -1.75) & (x >= -2.5)] = 12
result[(x < -2.5) & (x > -3.5)] = 13
result[(x <= -3.5) & (x >= -5.0)] = 14
result[x < -5.0] = 15
return result[:, ::2] + result[:, 1::2] * 16
def cast_from_fp4x2(x, dq_dtype):
"""Dequantize FP4 E2M1 tensor that has been represented in a byte tensor"""
fp4_values = torch.tensor(
[
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
-0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
],
device=x.device,
dtype=dq_dtype,
)
# Convert to long integers for indexing
second_bit = torch.div(x, 16, rounding_mode="floor").to(torch.long)
first_bit = (x - second_bit * 16).to(torch.long)
# Use the long integers to index fp4_values
first_bit_values = fp4_values[first_bit]
second_bit_values = fp4_values[second_bit]
result = torch.zeros(
(first_bit_values.shape[0], first_bit_values.shape[1] * 2),
device=x.device,
dtype=dq_dtype,
)
result[:, ::2] = first_bit_values
result[:, 1::2] = second_bit_values
return result
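# Illustrative sketch, not part of the original file: round-trip a few values
# that are exactly representable in FP4 E2M1 through the pack/unpack helpers
# above. Two 4-bit codes are packed per byte, so the packed width is halved.
def _example_fp4_roundtrip() -> None:
    x = torch.tensor([[0.5, -1.5, 2.0, 6.0]], dtype=torch.float32)
    packed = cast_to_fp4x2(x)  # shape (1, 2), dtype uint8
    restored = cast_from_fp4x2(packed, torch.float32)  # shape (1, 4)
    assert torch.equal(restored, x)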
def cast_to_e8(decode_scale):
"""Cast to a value that is representable in FP8 E8M0.
The result is in FP32, not FP8 E8M0.
"""
max_exponent = torch.tensor(127, device=decode_scale.device, dtype=torch.float32)
exponent = torch.ceil(torch.log2(decode_scale))
exponent = torch.clamp(exponent, min=-max_exponent, max=max_exponent)
return torch.tensor(2.0, device=decode_scale.device, dtype=torch.float32) ** exponent
def cast_to_e4m3(decode_scale, global_amax):
"""Scale and cast to FP8 E4M3.
Despite its name, decode_scale is the encoding scaling factor, and global_amax
may be any data tensor, not necessarily the amax.
TODO(etsykunov): Make this API more intuitive.
"""
decode_scale = decode_scale * global_amax
FLOAT8_E4M3_MAX = torch.tensor(448.0, device=decode_scale.device, dtype=torch.float32)
decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX)
return decode_scale.to(torch.float8_e4m3fn)
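# Illustrative sketch, not part of the original file: cast_to_e8 rounds a
# decode scale up to the nearest power of two (kept in FP32), while
# cast_to_e4m3 rescales by a global amax and clamps into the E4M3 range.
# Assumes a PyTorch build that supports casting to torch.float8_e4m3fn.
def _example_scale_casts() -> None:
    s = torch.tensor([0.3], dtype=torch.float32)
    assert cast_to_e8(s).item() == 0.5  # 2 ** ceil(log2(0.3))
    amax = torch.tensor([100.0], dtype=torch.float32)
    s_e4m3 = cast_to_e4m3(s, amax)  # 0.3 * 100 = 30.0, within the E4M3 range
    assert s_e4m3.dtype == torch.float8_e4m3fn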
def high_precision_gemm_ref(
a: torch.Tensor,
b: torch.Tensor,
out_dtype: torch.dtype,
accumulate: bool = False,
is_a_transposed: bool = False,
is_b_transposed: bool = False,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
scale_alpha: float = 1.0,
) -> torch.Tensor:
"""GEMM implementation with unquantized data"""
# Handle transpositions
mat1, mat2 = a, b
if is_a_transposed:
mat1 = a.T
if is_b_transposed:
mat2 = b.T
# Ensure dtype compatibility for torch.addmm
mat1 = mat1.to(out_dtype)
mat2 = mat2.to(out_dtype)
# Determine output shape
y_shape = (mat1.size(0), mat2.size(1))
if bias is not None:
assert not accumulate, "Bias is not supported with accumulation"
bias = bias.to(out_dtype)
# With bias case
if out_dtype == torch.float32:
y_ref = torch.addmm(bias.repeat(mat1.size(0), 1), mat1, mat2, beta=1, alpha=1)
else:
y_ref = torch.addmm(bias, mat1, mat2, beta=1, alpha=scale_alpha)
else:
# Without bias case
if accumulate and out is not None:
y_ref = out.clone().to(out_dtype)
else:
y_ref = torch.zeros(y_shape, dtype=out_dtype, device=a.device)
torch.addmm(y_ref, mat1, mat2, beta=1, alpha=scale_alpha, out=y_ref)
return y_ref
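# Illustrative sketch, not part of the original file: the reference GEMM
# computes mat1 @ mat2 in out_dtype and can transpose either operand, so a
# weight kept in (N, K) layout can be passed with is_b_transposed=True.
def _example_reference_gemm() -> None:
    a = torch.randn(4, 8)
    b = torch.randn(2, 8)  # weight in (N, K) layout
    y = high_precision_gemm_ref(a, b, torch.float32, is_b_transposed=True)
    assert torch.allclose(y, a @ b.T, atol=1e-5)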
class NVFP4TensorRef(ExperimentalQuantizedTensor):
"""NVFP4 tensor for middleware between Transformer Engine and Kitchen"""
def __repr__(self):
return (
f"{self.__class__.__name__}("
f"dtype={self.dtype}, "
f"device={self.device}, "
f"quant_dtype={self.quant_dtype}, "
f"data={self.dequantize(dtype=self.dtype)}, "
f"original_shape={self.original_shape}"
")"
)
def quantize_(
self,
tensor: torch.Tensor,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> ExperimentalQuantizedTensor:
"""In-place update of quantized data
Parameters
----------
tensor: torch.Tensor
Tensor to copy from
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
if isinstance(tensor, ExperimentalQuantizedTensor):
return self.quantize_(tensor.dequantize(), noop_flag=noop_flag)
self.get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Construct plain PyTorch tensor from quantized tensor
"""
if dtype is None:
dtype = self.dtype
# Ignore data_t for now
assert self.data is not None, "QuantizedTensor has no valid tensor data"
assert self.scale is not None, "QuantizedTensor has no valid scale"
tensor_data = self.data
tensor_scale = self.scale
# Dispatch to the quantizer
return self.get_quantizer().dequantize(tensor_data, tensor_scale, dtype=dtype)
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
columnwise_usage: Optional[bool] = None,
):
"""Generate or remove quantized data based on provided usage."""
has_data = self.data is not None
has_data_transpose = self.data_t is not None
needs_data = has_data
needs_data_transpose = has_data_transpose
if rowwise_usage is not None:
needs_data = rowwise_usage
if columnwise_usage is not None:
needs_data_transpose = columnwise_usage
# Generate data that is required
if needs_data and not has_data:
raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
if needs_data_transpose and not has_data_transpose:
if not has_data:
raise RuntimeError("FP8 data is required to generate FP8 data transpose")
self._create_transpose()
# Delete data that is not required
if not needs_data:
self.data = None
if not needs_data_transpose:
self.data_t = None
def _create_transpose(self):
"""Create transposed quantized tensor"""
if not self.data.is_contiguous():
self.data = self.data.contiguous()
self.data_t = self.data.t().contiguous()
self.scale_t = self.scale
def size(self, *args, **kwargs): # pylint: disable=unused-argument
"""Return the original tensor shape, not the internal packed data shape.
FP4 quantization packs two 4-bit values into each 8-bit value, which reduces
the second dimension by half. This method returns the logical shape that
users expect, not the internal packed storage shape.
"""
assert self.original_shape is not None
return torch.Size(self.original_shape)
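# Illustrative sketch, not part of the original file: the packed uint8 buffer
# stores two FP4 values per byte, so a logically (4, 16) tensor keeps a (4, 8)
# data buffer while size() reports the logical shape. Constructing the tensor
# directly like this is for demonstration only.
def _example_logical_size() -> None:
    packed = NVFP4TensorRef(
        data=torch.zeros(4, 8, dtype=torch.uint8),  # 4 x 16 FP4 values, packed pairwise
        original_shape=(4, 16),
    )
    assert packed.size() == torch.Size([4, 16])
    assert packed.data.shape == (4, 8)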
def get_wgrad_sign_vector() -> torch.Tensor:
"""Hard-coded signs for Hadamard transform"""
return torch.tensor(
[1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0],
dtype=torch.float32,
)
class NVFP4QuantizerRef(ExperimentalQuantizer):
"""NVFP4 quantizer for middleware between Transformer Engine and Kitchen"""
def __init__(
self,
dtype: utils.Fp4Formats,
rowwise: bool = True,
columnwise: bool = True,
pow_2_scales: bool = False,
eps: float = 0.0,
quant_tile_shape: Tuple[int, int] = (1, 16),
with_rht: bool = False,
with_random_sign_mask: bool = True,
):
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.dtype = dtype
self.pow_2_scales = pow_2_scales
self.eps = eps
self.quant_tile_shape = quant_tile_shape
self.with_rht = with_rht
self.with_random_sign_mask = with_random_sign_mask
@staticmethod
def _build_hadamard_matrix(
size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True
) -> torch.Tensor:
"""Construct a Hadamard matrix of given power-of-two size with entries +-1.
Uses Sylvester construction to avoid SciPy dependency.
"""
assert (size & (size - 1)) == 0, "Hadamard size must be a power of two"
h = torch.ones((1, 1), device=device, dtype=torch.float32)
while h.shape[0] < size:
h = torch.cat(
[
torch.cat([h, h], dim=1),
torch.cat([h, -h], dim=1),
],
dim=0,
)
if with_random_sign_mask:
sign_mat = get_wgrad_sign_vector().to(device) * torch.eye(
size, device=device, dtype=torch.float32
)
h = sign_mat @ h
return h.to(dtype)
def _apply_rht(self, x: torch.Tensor) -> torch.Tensor:
"""Apply randomized Hadamard transform without random signs (reference path).
This matches the reference used in tests: x_reshaped @ (H * (1/sqrt(g))).
"""
# Only apply when enabled
if not self.with_rht:
return x
# RHT dimension equals the quantization tile length (NVFP4 uses 16)
rht_dim = self.quant_tile_shape[1]
assert (
x.shape[-1] % rht_dim == 0
), f"Inner dimension {x.shape[-1]} must be divisible by hadamard dimension {rht_dim}"
# Build H and scale
H = self._build_hadamard_matrix(rht_dim, x.device, x.dtype, self.with_random_sign_mask)
scale = 1.0 / float(rht_dim) ** 0.5
# Perform blockwise transform along the last dimension
original_shape = x.shape
x_mat = x.contiguous().view(-1, rht_dim)
# Any sign mask has already been folded into H by _build_hadamard_matrix
transform = H * scale
out = x_mat @ transform
return out.view(original_shape)
@staticmethod
def _recover_swizzled_scales(
swizzled_scale: bool, scale: torch.Tensor, m: int, n: int, block_length: int
) -> torch.Tensor:
if not swizzled_scale:
return scale
rounded_m = utils.roundup_div(m, 128) * 128
scale_n = utils.roundup_div(n, block_length)
rounded_n = utils.roundup_div(scale_n, 4) * 4
# Recover swizzled scaling factor layout -> linear layout
tmp = torch.reshape(scale, (rounded_m // 128, rounded_n // 4, 32, 4, 4))
# after permutation, the layout is [rounded_m // 128, 4, 32, rounded_n // 4, 4]
tmp = torch.permute(tmp, (0, 3, 2, 1, 4))
result = torch.reshape(tmp, (rounded_m, rounded_n))
return result[:m, :scale_n]
@classmethod
def _quantize_blockwise_reference(
cls,
x: torch.Tensor,
global_amax: torch.Tensor,
tile_len_x: int,
tile_len_y: int,
*,
pow_2_scales: bool,
eps: float, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.ndim == 2
using_2d_quantization = tile_len_x == 16 and tile_len_y == 16
m, n = x.shape
# Compute vec_max based on the original x (before reshape)
# For 1D quantization: amax over each row chunk of 16
# For 2D quantization: amax over each 16x16 block, but output shape is still (128, 8, 1), filled with block amax
if using_2d_quantization:
# x shape: (128, 128)
x_blocks = (
x.unfold(0, tile_len_y, tile_len_y)
.unfold(1, tile_len_x, tile_len_x)
.to(torch.float32)
) # (8, 8, 16, 16)
block_amax = torch.amax(torch.abs(x_blocks), dim=(-1, -2)) # (8, 8)
# Now, expand to (128, 8, 1) by repeating each block_amax for 16 rows
vec_max = block_amax.repeat_interleave(tile_len_y, dim=0).unsqueeze(-1) # (128, 8, 1)
else:
# x shape: (128, 128)
x_reshaped = x.view(m, n // tile_len_x, tile_len_x) # (128, 8, 16)
vec_max = torch.amax(torch.abs(x_reshaped), dim=-1, keepdim=True).to(
torch.float32
) # (128, 8, 1)
x = x.view(m, n // tile_len_x, tile_len_x)
FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32)
FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32)
decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX)
if pow_2_scales:
decode_scale = cast_to_e8(decode_scale)
encode_scale = torch.div(
torch.tensor(1.0, device=x.device, dtype=torch.float32),
decode_scale.to(torch.float32),
)
else:
global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax)
global_encode_scale = torch.min(
global_encode_scale,
torch.tensor(
torch.finfo(torch.float32).max,
device=global_encode_scale.device,
dtype=torch.float32,
),
)
if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32):
global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32)
global_decode_scale = torch.div(1.0, global_encode_scale)
decode_scale = decode_scale * global_encode_scale
decode_scale = torch.min(
decode_scale,
torch.tensor(
torch.finfo(torch.float32).max,
device=decode_scale.device,
dtype=torch.float32,
),
)
decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX)
decode_scale = decode_scale.to(torch.float8_e4m3fn)
encode_scale = torch.min(
torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale),
torch.tensor(
torch.finfo(torch.float32).max,
device=decode_scale.device,
dtype=torch.float32,
),
)
scaled_x = x.to(torch.float32) * encode_scale
clipped_x = torch.clamp(scaled_x, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX).reshape(m, n)
return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1)
@staticmethod
def _pad_tensor(
tensor: torch.Tensor, row_divisor: Optional[int], col_divisor: Optional[int]
) -> torch.Tensor:
assert tensor.dim() == 2, "only supports 2D tensors"
M, N = tensor.shape
padding_needed_rows = 0
padding_needed_cols = 0
if row_divisor is not None and M % row_divisor != 0:
padding_needed_rows = row_divisor - (M % row_divisor)
# Check and calculate column padding if col_divisor is provided
if col_divisor is not None and N % col_divisor != 0:
padding_needed_cols = col_divisor - (N % col_divisor)
# Return original tensor if no padding is needed
if padding_needed_rows == 0 and padding_needed_cols == 0:
return tensor
# pad the tensor
out = torch.nn.functional.pad(
tensor,
(0, padding_needed_cols, 0, padding_needed_rows),
mode="constant",
value=0.0,
).contiguous()
return out
@staticmethod
def _rm_pad_tensor(tensor: torch.Tensor, original_size: tuple[int, ...]) -> torch.Tensor:
assert tensor.dim() == 2, "only supports 2D tensors"
M, N = original_size
out = tensor[:M, :N].contiguous()
return out
def _quantize(self, tensor: torch.Tensor) -> Tuple[
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
torch.Tensor,
torch.Tensor,
]:
"""
Python implementation of microblock FP4 quantization.
Parameters
----------
tensor : torch.Tensor
Input tensor to quantize (should be 2D)
Returns
-------
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor, torch.Tensor]
(qx, sx, qx_t, sx_t, global_amax_row, global_amax_col) where:
- qx: quantized data in row-major order (if rowwise_usage), None otherwise
- sx: scale tensor for qx (if rowwise_usage), None otherwise
- qx_t: quantized data in column-major order (if columnwise_usage), None otherwise
- sx_t: scale tensor for qx_t (if columnwise_usage), None otherwise
- global_amax_row: global amax of the row-wise input
- global_amax_col: global amax of the column-wise input
"""
if self.pow_2_scales:
assert self.quant_tile_shape == (
1,
32,
), "MXFP4 only supports 1x32 tile shape."
# TODO(etsykunov): Fix bug where global_amax_row and
# global_amax_col are not defined
# global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32)
else:
assert self.quant_tile_shape in (
(1, 16),
(16, 16),
), "NVFP4 only supports 1x16 or 16x16 tile shape."
# Prepare inputs once so we can reuse for both amax and quantization
# Row-input will always be the original input.
row_input = tensor
col_input = (
self._apply_rht(tensor.t().contiguous())
if self.with_rht
else tensor.t().contiguous()
)
# Compute amax for rowwise and columnwise paths separately
global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1)
global_amax_col = (
torch.max(torch.abs(col_input)).to(torch.float32).view(1)
if self.columnwise_usage
else global_amax_row
)
transpose_scales = False
M, N = tensor.shape
if self.rowwise_usage:
x_input = row_input
x_padded = self._pad_tensor(
x_input, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1]
)
qx, sx = self._quantize_blockwise_reference(
x_padded,
global_amax_row,
self.quant_tile_shape[1],
self.quant_tile_shape[0],
pow_2_scales=self.pow_2_scales,
eps=self.eps,
)
if transpose_scales:
sx = sx.T
qx = self._rm_pad_tensor(qx, (M, N // 2))
else:
qx = None
sx = None
if self.columnwise_usage:
x_t = col_input
x_t_padded = self._pad_tensor(
x_t, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1]
)
qx_t, sx_t = self._quantize_blockwise_reference(
x_t_padded,
global_amax_col,
self.quant_tile_shape[1],
self.quant_tile_shape[0],
pow_2_scales=self.pow_2_scales,
eps=self.eps,
)
qx_t = self._rm_pad_tensor(qx_t, (N, M // 2))
if transpose_scales:
sx_t = sx_t.T
else:
qx_t = None
sx_t = None
return qx, sx, qx_t, sx_t, global_amax_row, global_amax_col
def quantize(
self,
tensor: torch.Tensor,
**kwargs, # pylint: disable=unused-argument
) -> NVFP4TensorRef:
# sanity checks
assert tensor.dtype in utils.HIGH_PRECISION_FLOAT_DTYPES, "Unsupported input dtype."
# Make it work with 3D tensors
original_shape = tensor.shape
if tensor.ndim > 2:
tensor = tensor.view(-1, tensor.shape[-1])
qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(tensor)
return NVFP4TensorRef(
data=qx,
scale=sx,
data_t=qx_t,
scale_t=sx_t,
global_amax_row=global_amax_row,
global_amax_col=global_amax_col,
dtype=tensor.dtype,
device=tensor.device,
quant_dtype=self.dtype,
quantizer=self,
original_shape=original_shape,
)
def update_quantized(
self,
src: torch.Tensor,
dst: ExperimentalQuantizedTensor,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> ExperimentalQuantizedTensor:
"""Update the quantized tensor with the given tensor in-place
Parameters
----------
src: torch.Tensor
Source tensor to copy from
dst: ExperimentalQuantizedTensor
Destination ExperimentalQuantizedTensor to update
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
# Handle noop flag
if noop_flag is not None and noop_flag.item() != 0:
return dst
# Make sure input is in expected format
if not src.is_contiguous():
src = src.contiguous()
# Store the original shape and reshape for processing
original_shape = src.shape
if src.ndim > 2:
src = src.view(-1, src.shape[-1])
qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(src)
# Update the destination with new data
dst.data = qx
dst.scale = sx
dst.data_t = qx_t
dst.scale_t = sx_t
dst.global_amax_row = global_amax_row
dst.global_amax_col = global_amax_col
dst.dtype = src.dtype
dst.quant_dtype = self.dtype
dst.original_shape = original_shape
return dst
@property
def supports_allgather_fp8(self) -> bool:
"""Whether the tensor data can be all-gathered with an FP8 all-gather.
TODO(etsykunov): Confirm docstring is correct. Also, this API
seems too FP8-specific and should be reconsidered.
"""
return False
def transpose_qresult(
self, qresult: quantization.ExperimentalQuantizedTensor
) -> quantization.ExperimentalQuantizedTensor:
"""Convert row-wise data to column-wise data (?)
TODO(etsykunov): Confirm docstring is correct.
"""
raise NotImplementedError("Transpose qresult is not implemented for FP4.")
@property
def supports_dequantize(self) -> bool:
"""Whether quantized tensor can converted to high-precision tensor"""
return False
@property
def is_data_t_transposed_in_memory(self) -> bool:
"""Whether column-wise data is stored in transposed layout.
TODO(etsykunov): Confirm docstring is correct.
"""
raise NotImplementedError("Not implemented yet")
def dequantize(
self, tensor: torch.Tensor, scale: torch.Tensor, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
"""Dequantize the quantized tensor"""
raise NotImplementedError("Not implemented yet")
def qgemm(
self,
qx: torch.Tensor,
qw: torch.Tensor,
m_params: quantization.MMParams,
out_dtype: torch.dtype,
sx: torch.Tensor,
sw: torch.Tensor,
bias: torch.Tensor | None = None,
out: torch.Tensor | None = None,
accumulate: bool = False,
gemm_type: quantization.GEMMType = quantization.GEMMType.FPROP,
qresult_x: quantization.ExperimentalQuantizedTensor | None = None,
qresult_w: quantization.ExperimentalQuantizedTensor | None = None,
) -> torch.Tensor:
assert bias is None, "Bias is not implemented for FP4 GEMM."
high_precision_x = cast_from_fp4x2(qx, out_dtype)
high_precision_w = cast_from_fp4x2(qw, out_dtype)
if self.pow_2_scales:
if sx.dtype == torch.uint8:
# if scaling factor is stored in uint8 container
sx = torch.tensor(2.0, device=sx.device, dtype=torch.float32) ** (
(
sx.to(torch.float32)
- torch.tensor(127, device=sx.device, dtype=torch.float32)
)
)
sw = torch.tensor(2.0, device=sw.device, dtype=torch.float32) ** (
(
sw.to(torch.float32)
- torch.tensor(127, device=sw.device, dtype=torch.float32)
)
)
else:
# if scaling factor is torch.float8_e8m0fnu
sx = sx.to(torch.float32)
sw = sw.to(torch.float32)
alpha = torch.tensor(1.0, device=high_precision_x.device, dtype=torch.float32)
else:
assert qresult_x is not None
assert qresult_w is not None
assert qresult_x.global_amax_row is not None
assert qresult_w.global_amax_col is not None
sx = sx.to(torch.float32)
sw = sw.to(torch.float32)
factor = 6.0 * 6.0 * 448.0 * 448.0
if gemm_type == quantization.GEMMType.WGRAD:
partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col
else:
partial_alpha = qresult_x.global_amax_row * qresult_w.global_amax_row
alpha = torch.div(partial_alpha, factor).squeeze(-1)
M, K = high_precision_x.shape
N, K_w = high_precision_w.shape
assert K == K_w, "K dimension mismatch between qx and qw"
assert K % 32 == 0, "K dimension must be divisible by 32"
assert N % 8 == 0, "N dimension must be divisible by 8"
block_length = 32 if self.pow_2_scales else 16
grid_k = K // block_length
assert sx.shape == (
M,
K // block_length,
), f"sx shape mismatch: expected ({M}, {K//block_length}), got {sx.shape}"
assert sw.shape == (
N,
K // block_length,
), f"sw shape mismatch: expected ({N}, {K//block_length}), got {sw.shape}"
y = torch.zeros(M, N, dtype=torch.float32, device=qx.device)
# The implementation below mirrors the FP4 tensor core behavior:
# each output element (i, j) is an FP32 accumulation of (K // block_length) inner products,
# where each inner product is sx * sw * (1, block_length) x (block_length, 1) computed in FP32,
# and the computation is batched over the M and N dimensions.
for k in range(grid_k):
k_start = k * block_length
k_end = k_start + block_length
qx_block = high_precision_x[:, k_start:k_end].clone().contiguous()
qw_block = high_precision_w[:, k_start:k_end].clone().contiguous()
# Extract scaling factors for the current blocks
sx_block = sx[:, k]
sw_block = sw[:, k]
y += torch.outer(sx_block, sw_block) * high_precision_gemm_ref(
qx_block, qw_block, torch.float32, is_b_transposed=True
)
if not self.pow_2_scales and K > 0:
# only apply global scale for NVFP4 and non-empty cases
y = alpha * y
# accumulation happens at epilogue in float32
if accumulate:
assert out is not None, "Output tensor must be provided for accumulation."
y += out.to(torch.float32)
else:
assert out is None, "Output tensor should be None when accumulate is False."
y = y.to(out_dtype)
return y
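# Illustrative end-to-end sketch, not part of the original file: quantize an
# activation and a weight with the reference NVFP4 quantizer and run the
# reference FPROP GEMM. Shapes are chosen so that K is divisible by 32 and N
# by 8, as qgemm requires. Assumes a PyTorch build with FP8 dtype support.
def _example_nvfp4_fprop() -> None:
    torch.manual_seed(0)
    x = torch.randn(32, 32, dtype=torch.float32)  # (M, K) activation
    w = torch.randn(16, 32, dtype=torch.float32)  # (N, K) weight
    quantizer = NVFP4QuantizerRef(dtype=utils.Fp4Formats.E2M1)
    qx = quantizer.quantize(x)
    qw = quantizer.quantize(w)
    y = quantizer.qgemm(
        qx.data,
        qw.data,
        quantization.MMParams(),
        torch.float32,
        qx.scale,
        qw.scale,
        gemm_type=quantization.GEMMType.FPROP,
        qresult_x=qx,
        qresult_w=qw,
    )
    # The NVFP4 result should roughly track the high-precision x @ w.T.
    print((y - x @ w.T).abs().max())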
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility functions for experimental middleware between Transformer Engine and Kitchen."""
import enum
import torch
HIGH_PRECISION_FLOAT_DTYPES = (
torch.float,
torch.float16,
torch.bfloat16,
torch.float32,
)
class Fp4Formats(enum.Enum):
"""FP4 data format"""
E2M1 = "e2m1"
def roundup_div(x: int, y: int) -> int:
"""Round up division"""
assert x >= 0
assert y > 0
return (x + y - 1) // y
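# Illustrative sketch, not part of the original file: roundup_div is used to
# size the per-block scale tensors, e.g. K = 100 with 16-wide NVFP4 blocks
# needs ceil(100 / 16) = 7 scale columns.
def _example_roundup_div() -> None:
    assert roundup_div(100, 16) == 7
    assert roundup_div(96, 16) == 6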
@@ -21,6 +21,7 @@ from transformer_engine.common.recipe import (
MXFP8BlockScaling,
Float8CurrentScaling,
Float8BlockScaling,
NVFP4BlockScaling,
)
from .constants import dist_group_type
@@ -53,6 +54,13 @@ def check_mxfp8_support() -> Tuple[bool, str]:
return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
def check_nvfp4_support() -> Tuple[bool, str]:
"""Return if nvfp4 support is available"""
if get_device_compute_capability() >= (10, 0): # blackwell and above
return True, ""
return False, "Device compute capability 10.0 or higher required for NVFP4 execution."
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
"""Return if fp8 block scaling support is available"""
if (
@@ -105,6 +113,13 @@ def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType
return tex.DType.kFloat8E5M2
def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType:
"""Get fp4 data type according to recipe and tensor"""
if fp4_recipe.fp4_format == Format.E2M1:
return tex.DType.kFloat4E2M1
raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}")
def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
"""Get max representible FP8 value."""
if fp8_recipe.fp8_format == Format.E4M3 or (
@@ -142,6 +157,8 @@ class FP8GlobalStateManager:
reason_for_no_mxfp8 = ""
fp8_block_scaling_available = None
reason_for_no_fp8_block_scaling = None
nvfp4_available = None
reason_for_no_nvfp4 = ""
@classmethod
def reset(cls) -> None:
@@ -205,6 +222,13 @@ class FP8GlobalStateManager:
)
return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling
@classmethod
def is_nvfp4_available(cls) -> Tuple[bool, str]:
"""Return if NVFP4 support is available."""
if cls.nvfp4_available is None:
cls.nvfp4_available, cls.reason_for_no_nvfp4 = check_nvfp4_support()
return cls.nvfp4_available, cls.reason_for_no_nvfp4
@staticmethod
def get_meta_tensor_key(forward: bool = True) -> str:
"""Returns scaling key in `fp8_meta`."""
@@ -481,6 +505,9 @@ class FP8GlobalStateManager:
if isinstance(fp8_recipe, Float8BlockScaling):
fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available()
assert fp8_block_available, reason_for_no_fp8_block
if isinstance(fp8_recipe, NVFP4BlockScaling):
nvfp4_available, reason_for_no_nvfp4 = cls.is_nvfp4_available()
assert nvfp4_available, reason_for_no_nvfp4
@classmethod
def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
@@ -837,6 +864,8 @@ class RecipeState(abc.ABC):
cls = Float8CurrentScalingRecipeState
elif recipe.float8_block_scaling():
cls = Float8BlockScalingRecipeState
elif recipe.nvfp4():
cls = NVFP4BlockScalingRecipeState
else:
raise ValueError(f"{recipe.__class__.__name__} is not supported")
return cls(
@@ -1084,3 +1113,79 @@ class Float8BlockScalingRecipeState(RecipeState):
]
)
)
class NVFP4BlockScalingRecipeState(RecipeState):
"""Configuration for NVFP4 quantization.
NVFP4 quantization does not require state.
"""
recipe: NVFP4BlockScaling
mode: str
dtype: tex.DType
def __init__(
self,
recipe: NVFP4BlockScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.dtype = get_fp4_te_dtype(recipe)
# Allocate buffers
if device is None:
device = torch.device("cuda")
def make_quantizers(self) -> list:
from .tensor.nvfp4_tensor import NVFP4Quantizer
# The index convention (coming from base.py set_meta_tensor)
# is somewhat awkward. It assumes forward quantizers are
# ordered [input, weight, output, ...] and backward quantizers
# are ordered [grad_output, grad_input, ...]. This doesn't
# play nicely with fusible ops: Linear op doesn't own output
# or grad input quantizers, Quantize op only owns input and
# grad output quantizers.
if self.mode == "forward":
def _make_quantizer(idx: int) -> NVFP4Quantizer:
qparams = (
self.recipe.fp4_quant_fwd_weight
if idx % 3 == 1
else self.recipe.fp4_quant_fwd_inp
)
return NVFP4Quantizer(
fp4_dtype=self.dtype,
rowwise=True,
columnwise=True,
with_rht=qparams.random_hadamard_transform,
with_post_rht_amax=qparams.random_hadamard_transform,
with_2d_quantization=qparams.fp4_2d_quantization,
stochastic_rounding=qparams.stochastic_rounding,
)
return [_make_quantizer(idx) for idx in range(self.num_quantizers)]
if self.mode == "backward":
return [
NVFP4Quantizer(
fp4_dtype=self.dtype,
rowwise=True,
columnwise=True,
with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization,
stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding,
)
for _ in range(self.num_quantizers)
]
raise RuntimeError(f"Unexpected recipe mode ({self.mode})")
@@ -4,16 +4,18 @@
"""Internal function used by multiple modules."""
from typing import Any, List, Optional, Tuple, Union, Callable
import dataclasses
from dataclasses import dataclass
import queue
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from .. import cpp_extensions as tex
from .. import experimental
from ..constants import TE_DType
from ..utils import get_default_init_method
from ..export import is_in_onnx_export_mode
from ..tensor.utils import is_experimental
from ..utils import get_default_init_method
def _get_normalization_func(normalization: str, forward: bool):
@@ -170,7 +172,33 @@ def noop_cat(
return _NoopCatFunc.apply(dim, *tensors)
@dataclass
def get_module_quantizers(
module: torch.nn.Module,
fp8_output: bool,
fp8_grad: bool,
debug: bool,
):
"""Return the 6-tuple of quantizers for a module in a centralized way.
Routing policy:
- If experimental quantization is enabled via environment and module.fp8 is True,
return experimental quantizers.
- Otherwise, return the module's own quantizers (debug or regular).
"""
if getattr(module, "fp8", False) and is_experimental():
# TODO(etsykunov): Quantizer instantiation should be better
# done in the module's constructor
qlinear_params = experimental.config.set_qlinear_params()
if qlinear_params is not None:
return experimental.config.get_experimental_quantizers(module.fp8, qlinear_params)
if not debug:
return module._get_quantizers(fp8_output, fp8_grad)
return module._get_debug_quantizers(fp8_output, fp8_grad)
@dataclasses.dataclass
class _ParameterInitMeta:
"""
Stores essential metadata needed to support deferred parameter initialization.
...
@@ -27,6 +27,7 @@ from ..fp8 import (
DelayedScalingRecipeState,
Float8CurrentScalingRecipeState,
Float8BlockScalingRecipeState,
NVFP4BlockScalingRecipeState,
FP8GlobalStateManager,
RecipeState,
)
@@ -39,6 +40,7 @@ from ..distributed import (
from ..constants import dist_group_type
from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
@@ -76,7 +78,8 @@ class UserBufferQuantizationMode(Enum):
def get_cublas_workspace_size_bytes() -> None:
"""Return 32 MiB if using hopper, 4 MiB for all other architectures."""
if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
return 33_554_432
# 32 MiB for NVFP4 GEMM, plus 256 B for misc scales
return 32 * 1024 * 1024 + 256
return 4_194_304
@@ -757,6 +760,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
recipe_state, Float8BlockScalingRecipeState
):
return
if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState):
return
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
@@ -1218,15 +1223,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
):
grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
else:
if isinstance(quantizer, Float8BlockQuantizer):
# TODO(ksivaman): Re-add fusion once kernel is available.
if isinstance(quantizer, (Float8BlockQuantizer, NVFP4Quantizer)):
# unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else:
grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
if not isinstance(
grad_output,
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
):
if not isinstance(grad_output, QuantizedTensorBase):
grad_output = quantizer(grad_output)
return grad_output, grad_bias
...