Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
@@ -66,67 +66,102 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
// Input and param tensors
auto none = py::none();
const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none);
TensorWrapper bias_nvte;
if (bias.has_value()) {
bias_nvte = makeTransformerEngineTensor(*bias);
}
// Tensor dimensions
const auto shape = nvte_shape_to_vector(input_nvte.shape());
const auto outer_size = product(shape) / shape.back();
const auto inner_size = shape.back();
// Tensors to save for backward pass
at::Tensor mu_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
at::Tensor rsigma_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
TensorWrapper mu_nvte = makeTransformerEngineTensor(mu_py);
TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);
// Output tensor
auto quantizer_cpp = convert_quantizer(quantizer);
TensorWrapper out_nvte;
if (out.is_none()) {
std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
} else {
out_nvte = makeTransformerEngineTensor(out, quantizer);
}
// Choose implementation
enum class Impl {
// Compute norm in high precision, then quantize
UNFUSED,
// Compute norm directly
FULLY_FUSED,
// Compute norm and amax in high precision, then quantize to FP8
FUSED_NORM_AMAX_FP8,
// Compute norm and amax in high precision, then quantize to NVFP4
FUSED_NORM_AMAX_NVFP4
};
Impl impl = Impl::UNFUSED;
if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) {
impl = Impl::FULLY_FUSED;
} else if (IsMXFP8Quantizers(quantizer.ptr())) {
if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 &&
inner_size % 128 == 0) {
// cuDNN MXFP8 kernel requires full 128x128 tiles
impl = Impl::FULLY_FUSED;
}
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
impl = Impl::FUSED_NORM_AMAX_FP8;
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
impl = Impl::UNFUSED;
} else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
// TE kernel supports amax output
impl = Impl::FUSED_NORM_AMAX_NVFP4;
}
}
// Construct unquantized output tensor if needed
TensorWrapper unquantized_out_nvte;
py::object unquantized_out;
TensorWrapper *kernel_out_nvte = &out_nvte;
switch (impl) {
case Impl::UNFUSED: {
NoneQuantizer q{none};
std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype);
kernel_out_nvte = &unquantized_out_nvte;
} break;
case Impl::FUSED_NORM_AMAX_FP8: {
auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
std::tie(unquantized_out_nvte, unquantized_out) =
fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype);
kernel_out_nvte = &unquantized_out_nvte;
} break;
case Impl::FUSED_NORM_AMAX_NVFP4: {
auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
std::tie(unquantized_out_nvte, unquantized_out) =
nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype);
kernel_out_nvte = &unquantized_out_nvte;
} break;
default: {
}
}
// Query workspace size
TensorWrapper workspace;
NVTE_SCOPED_GIL_RELEASE({
nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps,
kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(),
workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
@@ -138,24 +173,31 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps,
kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(),
workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Quantize output if needed
switch (impl) {
case Impl::UNFUSED: {
quantizer_cpp->quantize(unquantized_out_nvte, out_nvte);
} break;
case Impl::FUSED_NORM_AMAX_FP8: {
auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
} break;
case Impl::FUSED_NORM_AMAX_NVFP4: {
auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
} break;
default: {
}
}
return {out, py::cast(mu_py), py::cast(rsigma_py)};
}
std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
@@ -254,61 +296,95 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
// Input and param tensors
auto none = py::none();
const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none);
// Tensor dimensions
const auto shape = nvte_shape_to_vector(input_nvte.shape());
const auto outer_size = product(shape) / shape.back();
const auto inner_size = shape.back();
// Tensors to save for backward pass
at::Tensor rsigma_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py);
// Output tensor
auto quantizer_cpp = convert_quantizer(quantizer);
TensorWrapper out_nvte;
if (out.is_none()) {
std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype);
} else {
out_nvte = makeTransformerEngineTensor(out, quantizer);
}
// Choose implementation
enum class Impl {
// Compute norm in high precision, then quantize
UNFUSED,
// Compute norm directly
FULLY_FUSED,
// Compute norm and amax in high precision, then quantize to FP8
FUSED_NORM_AMAX_FP8,
// Compute norm and amax in high precision, then quantize to NVFP4
FUSED_NORM_AMAX_NVFP4
};
Impl impl = Impl::UNFUSED;
if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) {
impl = Impl::FULLY_FUSED;
} else if (IsMXFP8Quantizers(quantizer.ptr())) {
if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 &&
inner_size % 128 == 0) {
// cuDNN MXFP8 kernel requires full 128x128 tiles
impl = Impl::FULLY_FUSED;
}
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
impl = Impl::FUSED_NORM_AMAX_FP8;
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
impl = Impl::UNFUSED;
} else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
// TE kernel supports amax output
impl = Impl::FUSED_NORM_AMAX_NVFP4;
}
}
// Construct unquantized output tensor if needed
TensorWrapper unquantized_out_nvte;
py::object unquantized_out;
TensorWrapper *kernel_out_nvte = &out_nvte;
switch (impl) {
case Impl::UNFUSED: {
NoneQuantizer q{none};
std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype);
kernel_out_nvte = &unquantized_out_nvte;
} break;
case Impl::FUSED_NORM_AMAX_FP8: {
auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
std::tie(unquantized_out_nvte, unquantized_out) =
fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype);
kernel_out_nvte = &unquantized_out_nvte;
} break;
case Impl::FUSED_NORM_AMAX_NVFP4: {
auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
std::tie(unquantized_out_nvte, unquantized_out) =
nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype);
kernel_out_nvte = &unquantized_out_nvte;
} break;
default: {
}
}
// Query workspace size
TensorWrapper workspace;
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(),
rsigma_nvte.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
@@ -320,24 +396,30 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(),
rsigma_nvte.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Quantize output if needed
switch (impl) {
case Impl::UNFUSED: {
quantizer_cpp->quantize(unquantized_out_nvte, out_nvte);
} break;
case Impl::FUSED_NORM_AMAX_FP8: {
auto fp8_quantizer_cpp = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
} break;
case Impl::FUSED_NORM_AMAX_NVFP4: {
auto nvfp4_quantizer_cpp = static_cast<NVFP4Quantizer *>(quantizer_cpp.get());
nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte);
} break;
default: {
}
}
return {out, py::none(), py::cast(rsigma_py)};
}
} // namespace transformer_engine::pytorch
@@ -23,15 +23,18 @@
namespace transformer_engine::pytorch {
PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove
PyTypeObject *Float8TensorStoragePythonClass = nullptr;
PyTypeObject *Float8QuantizerClass = nullptr;
PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr;
PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove
PyTypeObject *MXFP8TensorStoragePythonClass = nullptr;
PyTypeObject *MXFP8QuantizerClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorStoragePythonClass = nullptr;
PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
void init_float8_extension() {
if (Float8TensorPythonClass) return;
@@ -43,9 +46,9 @@ void init_float8_extension() {
Float8TensorPythonClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor"));
auto fp8_base_module =
py::module_::import("transformer_engine.pytorch.tensor.storage.float8_tensor_storage");
Float8TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorStorage"));
NVTE_CHECK(Float8TensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch Float8 extension.");
}
@@ -58,38 +61,54 @@ void init_mxfp8_extension() {
MXFP8TensorPythonClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Tensor"));
auto fp8_base_module =
py::module_::import("transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage");
MXFP8TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorStorage"));
NVTE_CHECK(MXFP8TensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch MXFP8 extension.");
}
void init_float8blockwise_extension() {
if (Float8BlockwiseQTensorStoragePythonClass) return;
auto fp8_module =
py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
auto fp8_base_module = py::module_::import(
"transformer_engine.pytorch.tensor.storage.float8_blockwise_tensor_storage");
Float8BlockwiseQuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer"));
Float8BlockwiseQTensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorStorage"));
Float8BlockwiseQTensorPythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor"));
NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorStoragePythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
}
void init_nvfp4_extensions() {
if (NVFP4TensorPythonClass) return;
auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor");
NVFP4QuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer"));
NVFP4TensorPythonClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Tensor"));
auto nvfp4_base_module =
py::module_::import("transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage");
NVFP4TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorStorage"));
NVTE_CHECK(NVFP4TensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch NVFP4 extension.");
}
void init_extension() {
init_float8_extension();
init_mxfp8_extension();
init_float8blockwise_extension();
init_nvfp4_extensions();
}
} // namespace transformer_engine::pytorch
@@ -136,6 +155,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("quantizer"));
m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu,
"SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"),
py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
/* Backward of GELU and variants */
m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
@@ -159,6 +181,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("fwd_input"), py::arg("quantizer"));
m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu,
"Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"),
py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
/* DBias + DAct fusions*/
m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize",
py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
...
@@ -31,22 +31,21 @@ namespace transformer_engine::pytorch {
} while (false);
extern PyTypeObject *Float8TensorPythonClass;
extern PyTypeObject *Float8TensorStoragePythonClass;
extern PyTypeObject *Float8QuantizerClass;
extern PyTypeObject *Float8CurrentScalingQuantizerClass;
extern PyTypeObject *MXFP8TensorPythonClass;
extern PyTypeObject *MXFP8TensorStoragePythonClass;
extern PyTypeObject *MXFP8QuantizerClass;
extern PyTypeObject *Float8BlockwiseQTensorPythonClass;
extern PyTypeObject *Float8BlockwiseQTensorStoragePythonClass;
extern PyTypeObject *Float8BlockwiseQuantizerClass;
extern PyTypeObject *NVFP4TensorPythonClass;
extern PyTypeObject *NVFP4TensorStoragePythonClass;
extern PyTypeObject *NVFP4QuantizerClass;
void init_extension();
namespace detail {
inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; }
@@ -56,22 +55,28 @@ inline bool IsFloat8CurrentScalingQuantizers(PyObject *obj) {
}
inline bool IsFloat8Tensor(PyObject *obj) {
return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorStoragePythonClass;
}
inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; }
inline bool IsMXFP8Tensor(PyObject *obj) {
return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorStoragePythonClass;
}
inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) {
return Py_TYPE(obj) == Float8BlockwiseQuantizerClass;
}
inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4QuantizerClass; }
inline bool IsFloat8BlockwiseQTensor(PyObject *obj) {
return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass ||
Py_TYPE(obj) == Float8BlockwiseQTensorStoragePythonClass;
}
inline bool IsNVFP4Tensor(PyObject *obj) {
return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorStoragePythonClass;
} }
TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer);
@@ -88,6 +93,8 @@ std::unique_ptr<Quantizer> CreateMXFP8Params(const py::handle params);
TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor,
Quantizer *quantization_params);
TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer);
inline bool IsFloatingPointType(at::ScalarType type) {
return type == at::kFloat || type == at::kHalf || type == at::kBFloat16;
}
@@ -100,8 +107,9 @@ constexpr std::array custom_types_converters = {
std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor,
CreateQuantizer<MXFP8Quantizer>),
std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers,
NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer<Float8BlockQuantizer>),
std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor,
CreateQuantizer<NVFP4Quantizer>)};
} // namespace detail
} // namespace transformer_engine::pytorch
...
@@ -31,8 +31,20 @@ std::vector<T> make_transpose_shape(const std::vector<S>& shape) {
return ret;
}
/*! @brief Convert shape for FP4 data by dividing the last dimension by 2 */
template <typename T = size_t>
std::vector<T> convert_shape_for_fp4(const std::vector<T>& shape) {
std::vector<T> ret;
for (size_t i = 0; i < shape.size() - 1; ++i) {
ret.push_back(shape[i]);
}
ret.push_back(shape.back() / 2);
return ret;
}
} // namespace
constexpr size_t NVFP4_BLOCK_SIZE = 16;
constexpr size_t MXFP8_BLOCK_SIZE = 32;
Quantizer::Quantizer(const py::handle& quantizer) {
@@ -140,7 +152,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
// Construct Python FP8 tensor
py::object out_py;
if (internal) {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass));
out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
"quantizer"_a = this->quantizer);
@@ -345,7 +357,7 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
py::object data_py = with_data ? py::cast(data_tensor) : py::none();
py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none();
if (internal) {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass));
out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
"quantizer"_a = this->quantizer);
@@ -376,10 +388,15 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
return {std::move(out_cpp), std::move(out_py)};
}
std::pair<TensorWrapper, py::object>
Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector<size_t>& shape,
DType dtype,
std::optional<at::Tensor> data) {
amax.zero_();
auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value())
: NoneQuantizer(py::none()).create_tensor(shape, dtype);
TensorWrapper out_cpp = std::move(out.first);
py::object out_py = std::move(out.second);
out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
getTensorShape(amax));
return {std::move(out_cpp), std::move(out_py)};
@@ -613,7 +630,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
py::object ret;
if (internal) {
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject*>(Float8BlockwiseQTensorStoragePythonClass));
ret = Float8BlockwiseQTensorClass(
"rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
"rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
@@ -899,7 +916,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
}
const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0,
"MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE,
" (got shape=", shape, ")");
const auto rowwise_scale_inv_shape = get_scale_shape(shape, false);
const auto columnwise_scale_inv_shape = get_scale_shape(shape, true);
@@ -933,7 +950,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
// Construct Python MXFP8 tensor
py::object out_py;
if (internal) {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass));
out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py,
"columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
@@ -1095,7 +1112,7 @@ std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& s
auto last_dim = shape.back();
NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
"MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE,
" (got shape=", shape, ")");
std::vector<size_t> scale_shape;
@@ -1116,4 +1133,573 @@ std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& s
return scale_shape;
}
NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
this->dtype = quantizer.attr("dtype").cast<DType>();
this->with_rht = quantizer.attr("with_rht").cast<bool>();
this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast<bool>();
this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast<bool>();
this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast<bool>();
// Get amax reduction group if needed for NVFP4 AG
const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast<bool>();
c10::intrusive_ptr<dist_group_type> amax_reduction_group;
if (with_amax_reduction) {
auto group = quantizer.attr("_canonicalized_amax_reduction_group")();
NVTE_CHECK(!group.is_none(), "NVFP4Quantizer could not canonicalize amax reduction group");
amax_reduction_group = group.cast<c10::intrusive_ptr<dist_group_type>>();
}
this->with_amax_reduction = with_amax_reduction;
this->amax_reduction_group = amax_reduction_group;
this->rht_matrix_random_sign_mask_t = quantizer.attr("rht_matrix_random_sign_mask_t").cast<int>();
this->rht_matrix = quantizer.attr("rht_matrix").cast<at::Tensor>();
}
void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const {
// set dtype for rowwise and columnwise data in tensor wrapper
auto rowwise_data = tensor->get_rowwise_data();
rowwise_data.dtype = static_cast<NVTEDType>(this->dtype);
auto columnwise_data = tensor->get_columnwise_data();
columnwise_data.dtype = static_cast<NVTEDType>(this->dtype);
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
rowwise_data.shape);
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::vector<size_t>& shape,
DType dtype) const {
using namespace pybind11::literals;
// Tensor dimensions
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
size_t flat_first_dim = 1;
if (shape.size() > 0) {
for (size_t i = 0; i < shape.size() - 1; ++i) {
flat_first_dim *= shape[i];
}
}
const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ",
NVFP4_BLOCK_SIZE, " (got shape=", shape, ")");
NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0,
"NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE,
" (got shape=", shape, ")");
const auto rowwise_scale_inv_shape = get_scale_shape(shape, false);
const auto columnwise_scale_inv_shape = get_scale_shape(shape, true);
// Allocate tensors
at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor, amax_rowwise;
at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise;
const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
if (rowwise_usage) {
const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
rowwise_scale_inv_shape.end());
rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts);
rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts);
amax_rowwise = at::empty({1}, bit32_tensor_opts);
}
if (columnwise_usage) {
const std::vector<int64_t> scale_inv_shape_int64(columnwise_scale_inv_shape.begin(),
columnwise_scale_inv_shape.end());
// Enforce a 2D shape to avoid cases like [S, B, H] where B can be 1:
// the transposed shape would be [H, S, B], and dividing its last dim by 2 would give zero
std::vector<int64_t> shape_int64_2d = {static_cast<int64_t>(flat_first_dim),
static_cast<int64_t>(flat_last_dim)};
const auto transpose_shape_int64 = make_transpose_shape<int64_t>(shape_int64_2d);
columnwise_data_tensor =
at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts);
columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts);
amax_columnwise = at::empty({1}, bit32_tensor_opts);
}
// Convert tensors to Python
auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object {
return need_cast ? py::cast(tensor) : py::none();
};
auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage);
auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage);
auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage);
auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage);
auto amax_rowwise_py = py_cast(amax_rowwise, rowwise_usage);
auto amax_columnwise_py = py_cast(amax_columnwise, columnwise_usage);
// Construct Python NVFP4 tensor
py::object out_py;
if (internal) {
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass));
out_py = NVFP4TensorClass(
"rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
"amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer);
} else {
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass));
out_py = NVFP4TensorClass(
"shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
"rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
"amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer);
}
// Construct C++ tensor
TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING);
if (rowwise_usage) {
out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape);
out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3,
rowwise_scale_inv_shape);
out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
if (columnwise_usage) {
// Enforce a 2D shape to avoid cases like [S, B, H] where B can be 1:
// the transposed shape would be [H, S, B], and dividing its last dim by 2 would give zero
std::vector<size_t> shape_2d = {flat_first_dim, flat_last_dim};
auto col_data_shape_fp4 = make_transpose_shape<size_t>(shape_2d);
out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), DType::kFloat4E2M1,
col_data_shape_fp4);
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3,
columnwise_scale_inv_shape);
out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(out_py)};
}
std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_unquantized_tensor_with_amax(
TensorWrapper& quantized_tensor, DType dtype) {
// Construct tensor
auto shape = convertShape(quantized_tensor.shape());
auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype);
// Register amax pointer from quantized tensor
void* amax_ptr = quantized_tensor.amax();
if (amax_ptr == nullptr) {
amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr;
}
NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor.");
out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
// Zero out amax
NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream()));
return {std::move(out_cpp), std::move(out_py)};
}
std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
py::object tensor) const {
NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor.");
// Extract buffers from Python tensor
auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
auto attr_py = tensor.attr(name);
if (attr_py.is_none()) {
return std::nullopt;
}
return attr_py.cast<at::Tensor>();
};
auto rowwise_data = get_tensor("_rowwise_data");
auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv");
auto columnwise_data = get_tensor("_columnwise_data");
auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv");
auto amax_rowwise = get_tensor("_amax_rowwise");
auto amax_columnwise = get_tensor("_amax_columnwise");
NVTE_CHECK(rowwise_data || columnwise_data, "NVFP4Tensor has no data.");
// Tensor dimensions, shape means original shape
std::vector<size_t> shape;
if (columnwise_data) {
shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true);
if (rowwise_data) {
auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape,
") and column-wise data (shape=", shape, ") do not match");
}
} else { // rowwise_data must be set, since we checked above that at least one buffer exists
shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
}
size_t flat_first_dim = 1;
if (shape.size() > 0) {
for (size_t i = 0; i < shape.size() - 1; ++i) {
flat_first_dim *= shape[i];
}
}
const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
// Coerce row-wise data
if (rowwise_usage) {
if (!rowwise_data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
rowwise_data = at::empty(convert_shape_for_fp4(shape_int64), opts);
tensor.attr("_rowwise_data") = *rowwise_data;
}
if (!rowwise_scale_inv) {
const auto scale_inv_shape = get_scale_shape(shape, false);
const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
scale_inv_shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
}
if (!amax_rowwise) {
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
amax_rowwise = at::empty({1}, opts);
tensor.attr("_amax_rowwise") = *amax_rowwise;
}
} else { // rowwise_usage == false
if (rowwise_data) {
rowwise_data.reset();
tensor.attr("_rowwise_data") = py::none();
}
if (rowwise_scale_inv) {
rowwise_scale_inv.reset();
tensor.attr("_rowwise_scale_inv") = py::none();
}
if (amax_rowwise) {
amax_rowwise.reset();
tensor.attr("_amax_rowwise") = py::none();
}
}
// Coerce column-wise data
if (columnwise_usage) {
if (!columnwise_data) {
// Enforce a 2D shape to avoid cases like [S, B, H] where B can be 1:
// the transposed shape would be [H, S, B], and dividing its last dim by 2 would give zero
std::vector<int64_t> shape_int64_2d = {static_cast<int64_t>(flat_first_dim),
static_cast<int64_t>(flat_last_dim)};
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
const auto transpose_shape_int64 = make_transpose_shape<int64_t>(shape_int64_2d);
columnwise_data = at::empty(convert_shape_for_fp4(transpose_shape_int64), opts);
tensor.attr("_columnwise_data") = *columnwise_data;
}
if (!columnwise_scale_inv) {
const auto scale_inv_shape = get_scale_shape(shape, true);
const std::vector<int64_t> scale_inv_shape_int64(scale_inv_shape.begin(),
scale_inv_shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
}
if (!amax_columnwise) {
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
amax_columnwise = at::zeros({1}, opts);
tensor.attr("_amax_columnwise") = *amax_columnwise;
}
} else { // columnwise_usage == false
if (columnwise_data) {
columnwise_data.reset();
tensor.attr("_columnwise_data") = py::none();
}
if (columnwise_scale_inv) {
columnwise_scale_inv.reset();
tensor.attr("_columnwise_scale_inv") = py::none();
}
if (amax_columnwise) {
amax_columnwise.reset();
tensor.attr("_amax_columnwise") = py::none();
}
}
// Construct C++ tensor
TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING);
if (rowwise_usage) {
out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape);
out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3,
getTensorShape(*rowwise_scale_inv));
out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
if (columnwise_usage) {
// Enforce a 2D shape to avoid cases like [S, B, H] where B can be 1:
// the transposed shape would be [H, S, B], and dividing its last dim by 2 would give zero
std::vector<size_t> shape_2d = {flat_first_dim, flat_last_dim};
auto col_data_shape_fp4 = make_transpose_shape<size_t>(shape_2d);
out_cpp.set_columnwise_data(columnwise_data->data_ptr(), DType::kFloat4E2M1,
col_data_shape_fp4);
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3,
getTensorShape(*columnwise_scale_inv));
out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
this->set_quantization_params(&out_cpp);
return {std::move(out_cpp), std::move(tensor)};
}
void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag,
bool compute_amax) {
// Nothing to be done if input is empty
if (input.numel() == 0) {
return;
}
auto stream = at::cuda::getCurrentCUDAStream();
QuantizationConfigWrapper quant_config;
if (noop_flag) {
quant_config.set_noop_tensor(noop_flag->data());
}
quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization);
quant_config.set_stochastic_rounding(this->stochastic_rounding);
// We only need RHT for columnwise usage.
// flat first dim and last dim for multi dimensional input
size_t rows = 1;
for (size_t i = 0; i < input.ndim() - 1; ++i) {
rows *= input.size(i);
}
size_t cols = input.size(input.ndim() - 1);
TensorWrapper te_rng_state;
if (this->stochastic_rounding) {
const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
auto rng_state = torch::empty({2}, opts);
philox_unpack(philox_args, static_cast<int64_t*>(rng_state.data_ptr()));
te_rng_state = makeTransformerEngineTensor(rng_state);
quant_config.set_rng_state(te_rng_state.data());
}
// Restriction for the RHT cast fusion kernel.
bool eligible_for_rht_cast_fusion =
input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
// Compute amax.
if (this->with_rht) {
if (input.dtype() != DType::kBFloat16) {
NVTE_CHECK(false, "RHT is only supported for bfloat16 input");
}
if (this->with_post_rht_amax) {
// We need:
// 1. Rowwise amax = amax for input
// 2. Columnwise amax = amax for RHT(input.t)
NVTE_SCOPED_GIL_RELEASE({
nvte_hadamard_transform_amax(input.data(), out.data(), 0,
this->rht_matrix_random_sign_mask_t, stream);
});
} else {
// raise error since it's not supported yet
NVTE_CHECK(false, "Pre-RHT amax is not supported yet");
}
} else { // Without RHT
if (compute_amax) {
// Amax pointers
auto rowwise_amax_ptr = out.get_amax().data_ptr;
auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr;
void* amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr;
NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer");
// Compute amax of input tensor
out.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
NVTE_SCOPED_GIL_RELEASE(
{ nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); });
out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector<size_t>{1});
// Make sure row-wise and column-wise amaxes match
if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float),
cudaMemcpyDeviceToDevice, stream));
}
if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float),
cudaMemcpyDeviceToDevice, stream));
}
}
}
// amax reduction
if (this->with_amax_reduction) {
std::vector<at::Tensor> amax_tensors;
// push amax tensors inside if they need to be reduced
auto make_amax_tensor = [](void* data_ptr) {
return at::from_blob(
data_ptr, std::vector<int64_t>{1},
[](void*) {}, // deleter doing nothing since it doesn't own the data
at::device(at::kCUDA).dtype(torch::kFloat32));
};
if (rowwise_usage) {
amax_tensors.push_back(make_amax_tensor(out.get_amax().data_ptr));
}
if (columnwise_usage) {
amax_tensors.push_back(make_amax_tensor(out.get_columnwise_amax().data_ptr));
}
c10d::AllreduceCoalescedOptions opts;
opts.reduceOp = c10d::ReduceOp::MAX;
NVTE_SCOPED_GIL_RELEASE(
{ this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); });
}
if (this->with_rht) {
if (rowwise_usage) {
// For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise
TensorWrapper out_identity(out.scaling_mode());
auto out_identity_data = out.get_rowwise_data();
auto out_identity_scale_inv = out.get_rowwise_scale_inv();
auto out_identity_amax = out.get_amax();
out_identity.set_rowwise_data(out_identity_data.data_ptr,
static_cast<DType>(out_identity_data.dtype),
out_identity_data.shape);
out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr,
static_cast<DType>(out_identity_scale_inv.dtype),
out_identity_scale_inv.shape);
out_identity.set_amax(out_identity_amax.data_ptr, static_cast<DType>(out_identity_amax.dtype),
out_identity_amax.shape);
NVTE_SCOPED_GIL_RELEASE(
{ nvte_quantize_v2(input.data(), out_identity.data(), quant_config, stream); });
}
if (columnwise_usage) {
// Get the output columnwise data, scale_inv, and amax
auto out_columnwise_data = out.get_columnwise_data();
auto out_columnwise_scale_inv = out.get_columnwise_scale_inv();
// NOTE: should already be populated.
auto out_columnwise_amax = out.get_columnwise_amax();
// Create a wrapper that exposes the columnwise output as a rowwise output.
// This works because the input `rht_output_t` is already in the transposed layout,
// so a rowwise quantization is enough to produce the columnwise output.
TensorWrapper out_transpose(out.scaling_mode());
// Note: since the columnwise tensor is presented as a rowwise one, the flat-first-dim check
// would fail, so the shape is flattened to 2D here.
auto colwise_data_shape = out_columnwise_data.shape;
std::vector<size_t> colwise_data_shape_2d;
// Example: a stored shape of [512, 32, 64] represents [512, 32, 128] elements because two FP4
// values share one byte. The flattened 2D shape should be [512, 32*128], but the columnwise data
// shape expects the last dim to be halved again, so the factors of 2 cancel out.
colwise_data_shape_2d.push_back(colwise_data_shape.data[0]);
size_t last_dim = 1;
for (size_t i = 1; i < colwise_data_shape.ndim; ++i) {
last_dim *= colwise_data_shape.data[i];
}
colwise_data_shape_2d.push_back(last_dim);
out_transpose.set_rowwise_data(out_columnwise_data.data_ptr,
static_cast<DType>(out_columnwise_data.dtype),
colwise_data_shape_2d);
out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr,
static_cast<DType>(out_columnwise_scale_inv.dtype),
out_columnwise_scale_inv.shape);
out_transpose.set_amax(out_columnwise_amax.data_ptr,
static_cast<DType>(out_columnwise_amax.dtype),
out_columnwise_amax.shape);
if (!eligible_for_rht_cast_fusion) {
// Invoking fallback RHT kernel.
// If using RHT, then amax will be computed in the RHT step
// If not using RHT, then amax will be computed based on input x
at::Tensor rht_output_t; // The RHT(x_t) output, in columnwise layout
// This wrapper is going to be passed as input to the quantization kernel.
TensorWrapper rht_output_t_cpp;  // Wrapper holding the RHT(input.t) output
rht_output_t =
allocateTorchTensor(static_cast<int>(cols), static_cast<int>(rows), input.dtype());
// NOTE (frsun): This is non-intuitive: the result of the transposed RHT is
// written into the rowwise output slot.
rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(),
std::vector<size_t>{cols, rows});
NVTE_SCOPED_GIL_RELEASE({
// Compute RHT(input.t) and write it into rht_output_t_cpp (stored in its rowwise slot).
nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0,
this->rht_matrix_random_sign_mask_t, stream);
});
// Quantize kernel will treat everything as rowwise input/output, which is
// intended.
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config, stream);
});
} else {
// RHT cast fusion kernel.
NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0,
"RHT matrix is not set");
auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix);
NVTE_SCOPED_GIL_RELEASE({
nvte_hadamard_transform_cast_fusion_columnwise(
input.data(), out_transpose.data(), rht_matrix_nvte.data(), quant_config, stream);
});
}
}
} else {
NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); });
}
}
void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag) {
this->quantize_impl(input, out, noop_flag, true);
}
void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) {
// Update output tensor amaxes with input tensor amax
auto input_amax_ptr = input.amax();
auto output_rowwise_amax_ptr = out.get_amax().data_ptr;
auto output_columnwise_amax_ptr = out.get_columnwise_amax().data_ptr;
NVTE_CHECK(input_amax_ptr != nullptr ||
(output_rowwise_amax_ptr == nullptr && output_columnwise_amax_ptr == nullptr),
"Input tensor does not have pre-computed amax");
if (input_amax_ptr != output_rowwise_amax_ptr && input_amax_ptr != nullptr &&
output_rowwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(output_rowwise_amax_ptr, input_amax_ptr, sizeof(float),
cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream()));
}
if (input_amax_ptr != output_columnwise_amax_ptr && input_amax_ptr != nullptr &&
output_columnwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(output_columnwise_amax_ptr, input_amax_ptr, sizeof(float),
cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream()));
}
input.set_amax(nullptr, DType::kFloat32, input.defaultShape);
// Perform quantization
this->quantize_impl(input, out, std::nullopt, false);
}
std::vector<size_t> NVFP4Quantizer::get_scale_shape(const std::vector<size_t>& shape,
bool columnwise) const {
size_t numel = 1;
for (auto s : shape) {
numel *= s;
}
auto last_dim = shape.back();
auto flat_first_dim = numel / last_dim;
NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ",
NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")");
NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0,
"NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE,
" (got shape=", shape, ")");
std::vector<size_t> scale_shape;
bool rowwise_usage = !columnwise;
if (rowwise_usage) {
// rowwise scaling factor shape
size_t sinv0 = roundup(flat_first_dim, 128);
size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4);
scale_shape = {sinv0, sinv1};
} else {
// columnwise scaling factor shape
size_t sinv0 = roundup(last_dim, 128);
size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4);
scale_shape = {sinv0, sinv1};
}
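// Worked example (illustrative, assuming NVFP4_BLOCK_SIZE == 16): for shape {1024, 2048},
// the rowwise scales are {roundup(1024, 128), roundup(2048 / 16, 4)} = {1024, 128} and the
// columnwise scales are {roundup(2048, 128), roundup(1024 / 16, 4)} = {2048, 64}.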
return scale_shape;
}
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
...@@ -116,6 +116,46 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer ...@@ -116,6 +116,46 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer
return ret; return ret;
} }
TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) {
const DType dtype = tensor.attr("_fp4_dtype").cast<DType>();
auto ret = TensorWrapper(NVTE_NVFP4_1D_SCALING);
bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor.");
// Row-scaled data
if (rowwise_usage) {
const auto &data = tensor.attr("_rowwise_data").cast<at::Tensor>();
const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast<at::Tensor>();
ret.set_rowwise_data(data.data_ptr(), dtype,
convert_shape_back_from_fp4(getTensorShape(data), false));
ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv));
ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise));
}
// Column-scaled data
if (columnwise_usage) {
const auto &data = tensor.attr("_columnwise_data").cast<at::Tensor>();
const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast<at::Tensor>();
ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1,
convert_shape_back_from_fp4(getTensorShape(data), false));
ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3,
getTensorShape(scale_inv));
ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32,
getTensorShape(amax_columnwise));
}
// Quantizer state
quantizer->set_quantization_params(&ret);
return ret;
}
} // namespace detail } // namespace detail
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "util.h" #include "util.h"
#include "common.h" #include "common.h"
#include "common/common.h"
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input, std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
bool rowwise) { bool rowwise) {
...@@ -14,22 +15,31 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap ...@@ -14,22 +15,31 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
if (input.scaling_mode() == NVTE_INVALID_SCALING) { if (input.scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle."); NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING &&
input.scaling_mode() != NVTE_NVFP4_1D_SCALING) {
return std::nullopt; return std::nullopt;
} }
NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); NVTE_CHECK(input.element_size_bits() == 4 || input.element_size_bits() == 8,
"4-bit or 8-bit input required for swizzling scaling factors.");
const auto nvfp4 = input.scaling_mode() == NVTE_NVFP4_1D_SCALING;
NVTEBasicTensor scale_inv; NVTEBasicTensor scale_inv;
NVTEShape nvte_input_shape;
if (rowwise) { if (rowwise) {
nvte_input_shape = input.shape();
scale_inv = input.get_rowwise_scale_inv(); scale_inv = input.get_rowwise_scale_inv();
} else { } else {
nvte_input_shape = input.get_columnwise_data().shape;
scale_inv = input.get_columnwise_scale_inv(); scale_inv = input.get_columnwise_scale_inv();
} }
auto input_shape = nvte_shape_to_vector(input.shape()); auto input_shape = nvte_shape_to_vector(nvte_input_shape);
auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape.");
// Allocate memory for swizzled output. // Allocate memory for swizzled output.
auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
std::vector<int64_t> scale_inv_shape_int; std::vector<int64_t> scale_inv_shape_int;
...@@ -41,36 +51,34 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap ...@@ -41,36 +51,34 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
// Reconstruct input only to avoid swizzling both directions if not needed. // Reconstruct input only to avoid swizzling both directions if not needed.
// Use any 8 bit type, it's irrelevant. // The specific dtype is irrelevant; it only needs the correct bit width.
transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); transformer_engine::TensorWrapper input_cu(input.scaling_mode());
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); transformer_engine::TensorWrapper output_cu(input.scaling_mode());
const auto input_dtype =
(nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3;
const auto scale_inv_dtype =
(nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0;
if (rowwise) { if (rowwise) {
input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
scale_inv_shape); output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape);
output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shape);
} else { } else {
input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
input_shape); input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape);
scale_inv_shape); output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
output_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
input_shape);
output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
transformer_engine::DType::kFloat8E8M0, scale_inv_shape);
} }
// Launch kernel // Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
if (rowwise) { if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
scale_inv_shape);
} else { } else {
input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
scale_inv_shape);
} }
return swizzled_scale_inv; return swizzled_scale_inv;
...@@ -170,3 +178,72 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors( ...@@ -170,3 +178,72 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
return buffer; return buffer;
} }
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input,
bool rowwise) {
using namespace transformer_engine::pytorch;
using transformer_engine::DIVUP;
// Check input tensor
const NVTEScalingMode scaling_mode = input.scaling_mode();
NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
"Input tensor must be a block scaling tensor");
// Get tensor data
NVTEBasicTensor data;
size_t data_flat_first_dim = 1;
size_t data_flat_last_dim = 1;
if (rowwise) {
data = input.get_rowwise_data();
for (int i = 0; i < data.shape.ndim - 1; ++i) {
data_flat_first_dim *= data.shape.data[i];
}
data_flat_last_dim = data.shape.data[data.shape.ndim - 1];
} else {
data = input.get_columnwise_data();
data_flat_first_dim = data.shape.data[0];
for (int i = 1; i < data.shape.ndim; ++i) {
data_flat_last_dim *= data.shape.data[i];
}
}
NVTEShape data_shape{};
data_shape.data[0] = data_flat_first_dim;
data_shape.data[1] = data_flat_last_dim;
data_shape.ndim = 2;
// Recreate input tensor with rowwise usage
transformer_engine::TensorWrapper input_cu(scaling_mode);
input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
const NVTEBasicTensor scale_inv =
rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv();
input_cu.set_rowwise_scale_inv(
scale_inv.data_ptr, static_cast<transformer_engine::DType>(scale_inv.dtype), scale_inv.shape);
// Create output tensor
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
// Output swizzled mxfp8 scaling factor dimensions
const size_t swizzled_scale_inv_first_dim = DIVUP<size_t>(data_flat_first_dim, 128) * 128;
const size_t swizzled_scale_inv_last_dim = DIVUP<size_t>(data_flat_last_dim, 128) * 4;
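// Rationale (informational): MXFP8 uses one E8M0 scale per 32 elements and the swizzled layout
// groups scales into 128x4 tiles, which is why rows pad to a multiple of 128 and each 128 data
// columns map to 4 scale bytes.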
// Allocate memory for swizzled mxfp8 scaling factors
const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
at::Tensor swizzled_scale_inv = at::empty(
std::vector<int64_t>{static_cast<int64_t>(swizzled_scale_inv_first_dim),
static_cast<int64_t>(swizzled_scale_inv_last_dim)},
options);
// Set rowwise scaling factors on output
void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
NVTEShape swizzled_scale_inv_shape{};
swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim;
swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim;
swizzled_scale_inv_shape.ndim = 2;
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
swizzled_scale_inv_shape);
// Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format
nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
// Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling factor
// for it to be kept alive during the GEMM
input = std::move(output_cu);
return swizzled_scale_inv;
}
...@@ -27,4 +27,16 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap ...@@ -27,4 +27,16 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors( std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
std::vector<transformer_engine::TensorWrapper> &inputs, bool rowwise); std::vector<transformer_engine::TensorWrapper> &inputs, bool rowwise);
/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place.
*
* If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid
* transposing it in memory. Due to differences in how block scaling and mxfp8 store data,
* this requires the calling code to treat the output tensor as having been transposed in this case.
*
* Returns the swizzled scaling factor of the converted mxfp8 tensor.
* The returned swizzled scaling factor tensor should be kept alive during the GEMM.
*/
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input,
bool rowwise);
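/* Minimal usage sketch (illustrative, not part of the original header):
*   transformer_engine::TensorWrapper wrapper = ...;  // block-scaling tensor
*   at::Tensor keep_alive = convert_block_scaling_to_mxfp8_tensor(wrapper, true);
*   // `wrapper` now describes an MXFP8 tensor; keep `keep_alive` alive until the GEMM completes.
*/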
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
...@@ -36,14 +36,17 @@ from .utils import ( ...@@ -36,14 +36,17 @@ from .utils import (
needs_quantized_gemm, needs_quantized_gemm,
) )
from .constants import dist_group_type from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast from .quantization import FP8GlobalStateManager, autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.nvfp4_tensor import NVFP4Quantizer
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
from .tensor.quantized_tensor import QuantizedTensor, Quantizer from .tensor.quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor.storage.float8_tensor_storage import Float8TensorStorage
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from .triton.pad import pad_columnwise_scale_inv
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
...@@ -416,8 +419,8 @@ class _CheckpointFunction(torch.autograd.Function): ...@@ -416,8 +419,8 @@ class _CheckpointFunction(torch.autograd.Function):
detached_inputs = detach_variable(inputs) detached_inputs = detach_variable(inputs)
with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward( with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
activation_recompute=True, recompute_phase=True activation_recompute=True, recompute_phase=True
), fp8_autocast( ), autocast(
enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe enabled=ctx.fp8, recipe=ctx.fp8_recipe
): ):
outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
...@@ -751,8 +754,8 @@ def checkpoint( ...@@ -751,8 +754,8 @@ def checkpoint(
def recompute_fn(*args, **kwargs): def recompute_fn(*args, **kwargs):
with torch.autograd.enable_grad(), ( with torch.autograd.enable_grad(), (
te_recompute_ctx te_recompute_ctx
), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, fp8_autocast( ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, autocast(
enabled=fp8, fp8_recipe=fp8_recipe enabled=fp8, recipe=fp8_recipe
): ):
function(*args, **kwargs) function(*args, **kwargs)
...@@ -904,7 +907,7 @@ def _all_gather_fp8( ...@@ -904,7 +907,7 @@ def _all_gather_fp8(
async_op: bool = False, async_op: bool = False,
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
out_shape: Optional[list[int]] = None, out_shape: Optional[list[int]] = None,
) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]: ) -> tuple[Float8TensorStorage, Optional[torch.distributed.Work]]:
"""All-gather FP8 tensor along first dimension.""" """All-gather FP8 tensor along first dimension."""
world_size = get_distributed_world_size(process_group) world_size = get_distributed_world_size(process_group)
...@@ -922,7 +925,7 @@ def _all_gather_fp8( ...@@ -922,7 +925,7 @@ def _all_gather_fp8(
# Cast input tensor to FP8 if needed # Cast input tensor to FP8 if needed
# Note: We cannot directly all-gather the transposed FP8 tensor, # Note: We cannot directly all-gather the transposed FP8 tensor,
# so temporarily modify quantizer to avoid creating FP8 transpose. # so temporarily modify quantizer to avoid creating FP8 transpose.
if not isinstance(inp, Float8TensorBase): if not isinstance(inp, Float8TensorStorage):
assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
# we cannot directly gather the transposed fp8 tensor # we cannot directly gather the transposed fp8 tensor
# so we need to disable columnwise usage for the quantizer # so we need to disable columnwise usage for the quantizer
...@@ -937,7 +940,7 @@ def _all_gather_fp8( ...@@ -937,7 +940,7 @@ def _all_gather_fp8(
) )
# Construct output tensor # Construct output tensor
out: Float8TensorBase out: Float8TensorStorage
if quantizer is not None: if quantizer is not None:
dtype = torch.float32 dtype = torch.float32
device = "cuda" device = "cuda"
...@@ -955,7 +958,7 @@ def _all_gather_fp8( ...@@ -955,7 +958,7 @@ def _all_gather_fp8(
out._transpose = None out._transpose = None
out._transpose_invalid = True out._transpose_invalid = True
else: else:
raise RuntimeError("FP8TensorBase is not supported yet without Quantizer") raise RuntimeError("Float8TensorStorage is not supported yet without Quantizer")
# Assume scaling factors are identical across ranks # Assume scaling factors are identical across ranks
out._scale_inv = inp._scale_inv out._scale_inv = inp._scale_inv
...@@ -1000,10 +1003,10 @@ def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None: ...@@ -1000,10 +1003,10 @@ def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
def _post_process_fp8_blockwise_gather( def _post_process_fp8_blockwise_gather(
out: Float8BlockwiseQTensorBase, out: Float8BlockwiseQTensorStorage,
quantizer: Float8BlockQuantizer, quantizer: Float8BlockQuantizer,
handle: Optional[torch.distributed.Work] = None, handle: Optional[torch.distributed.Work] = None,
) -> Float8BlockwiseQTensorBase: ) -> Float8BlockwiseQTensorStorage:
"""Post-process FP8 blockwise gather.""" """Post-process FP8 blockwise gather."""
if handle is not None: if handle is not None:
handle.wait() handle.wait()
...@@ -1037,7 +1040,7 @@ def _post_process_fp8_blockwise_gather( ...@@ -1037,7 +1040,7 @@ def _post_process_fp8_blockwise_gather(
class _FP8BlockwiseAllGatherAsyncHandle: class _FP8BlockwiseAllGatherAsyncHandle:
"""Handle for asynchronous FP8 blockwise all-gather.""" """Handle for asynchronous FP8 blockwise all-gather."""
tensor: Float8BlockwiseQTensorBase tensor: Float8BlockwiseQTensorStorage
quantizer: Float8BlockQuantizer quantizer: Float8BlockQuantizer
async_handle: torch.distributed.Work async_handle: torch.distributed.Work
_synchronized: bool = False _synchronized: bool = False
...@@ -1075,18 +1078,18 @@ def _all_gather_fp8_blockwise( ...@@ -1075,18 +1078,18 @@ def _all_gather_fp8_blockwise(
if isinstance(inp, torch.Tensor): if isinstance(inp, torch.Tensor):
device = inp.device device = inp.device
dtype = inp.dtype dtype = inp.dtype
elif isinstance(inp, Float8BlockwiseQTensorBase): elif isinstance(inp, Float8BlockwiseQTensorStorage):
if inp._rowwise_data is not None: if inp._rowwise_data is not None:
device = inp._rowwise_data.device device = inp._rowwise_data.device
elif inp._columnwise_data is not None: elif inp._columnwise_data is not None:
device = inp._columnwise_data.device device = inp._columnwise_data.device
else: else:
raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data") raise ValueError("Got Float8BlockwiseQTensorStorage input tensor without any data")
dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant. dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant.
else: else:
raise ValueError( raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, " "Invalid type for input tensor (expected torch.Tensor or"
f"found {inp.__class__.__name__})" f" Float8BlockwiseQTensorStorage, found {inp.__class__.__name__})"
) )
world_size = get_distributed_world_size(process_group) world_size = get_distributed_world_size(process_group)
...@@ -1103,7 +1106,7 @@ def _all_gather_fp8_blockwise( ...@@ -1103,7 +1106,7 @@ def _all_gather_fp8_blockwise(
# Doing BF16 gather for now as baseline because it's simpler # Doing BF16 gather for now as baseline because it's simpler
if ( if (
not isinstance(inp, Float8BlockwiseQTensorBase) not isinstance(inp, Float8BlockwiseQTensorStorage)
and quantizer is not None and quantizer is not None
and not quantizer.is_quantizable(inp) and not quantizer.is_quantizable(inp)
): ):
...@@ -1128,7 +1131,7 @@ def _all_gather_fp8_blockwise( ...@@ -1128,7 +1131,7 @@ def _all_gather_fp8_blockwise(
# Set to compact usage in case the quantizer is not correctly configured # Set to compact usage in case the quantizer is not correctly configured
orig_all_gather_usage = quantizer.all_gather_usage orig_all_gather_usage = quantizer.all_gather_usage
quantizer.all_gather_usage = True quantizer.all_gather_usage = True
if not isinstance(inp, Float8BlockwiseQTensorBase): if not isinstance(inp, Float8BlockwiseQTensorStorage):
inp = quantizer(inp) inp = quantizer(inp)
elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
quantizer.columnwise_usage and inp._columnwise_data is None quantizer.columnwise_usage and inp._columnwise_data is None
...@@ -1204,6 +1207,245 @@ def _all_gather_fp8_blockwise( ...@@ -1204,6 +1207,245 @@ def _all_gather_fp8_blockwise(
return out, handle return out, handle
def _swap_first_dims(tensor: torch.Tensor, world_size: int):
"""
Swap first 2 dimensions of a tensor to fix interleaved
data format after gathering transposed data.
For more than 2 dimensions, we squash the trailing dimensions,
instead of the first few dimensions, that's because the shape
passed in this function is already transposed.
"""
shape = tensor.shape
assert tensor.ndim >= 2, "Wrong number of dimensions for fixing interleave."
first_dim = shape[0]
flattened_trailing = math.prod(shape[1:])
assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave."
tensor = tensor.reshape(world_size, first_dim // world_size, flattened_trailing)
tensor = tex.swap_first_dims(tensor, out=None)
return tensor.reshape(first_dim // world_size, flattened_trailing * world_size)
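# Illustrative example (not from the original code): with world_size=2, all_gather stacks the
# per-rank transposed shards A and B along dim 0 as [A; B] with shape (2r, c), while the logical
# transposed tensor needs them side by side as [A | B] with shape (r, 2c). _swap_first_dims views
# the gathered tensor as (world_size, r, c), swaps the first two dims, and flattens to (r, 2c).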
def _post_process_nvfp4_gather(
out: NVFP4TensorStorage,
columnwise_data_interleaved: torch.Tensor,
columnwise_scale_inv_interleaved: torch.Tensor,
world_size: int,
handle: Optional[torch.distributed.Work] = None,
) -> NVFP4TensorStorage:
"""Post-process FP8 blockwise gather."""
if handle is not None:
handle.wait()
handle = None
# Fix the interleaved transposed data from gathering along first dim.
out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
# Optionally pad the scaling inverse if needed.
out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
@dataclass
class _NVFP4AllGatherAsyncHandle:
"""Handle for asynchronous NVFP4 all-gather."""
output: NVFP4TensorStorage
columnwise_data_interleaved: torch.Tensor
columnwise_scale_inv_interleaved: torch.Tensor
world_size: int
async_handle: torch.distributed.Work
_synchronized: bool = False
def wait(self) -> None:
"""Wait for the async operation to complete and post-process the tensor."""
if self._synchronized:
return
self.async_handle.wait()
_post_process_nvfp4_gather(
self.output,
self.columnwise_data_interleaved,
self.columnwise_scale_inv_interleaved,
self.world_size,
)
self._synchronized = True
def _all_gather_nvfp4(
inp: torch.Tensor,
process_group: dist_group_type,
*,
async_op: bool = False,
quantizer: NVFP4Quantizer,
out_shape: Optional[list[int]] = None,
) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]:
"""All-gather NVFP4 tensor along first dimension."""
# Input tensor attributes
in_shape: Iterable[int] = None
in_shape_t: Iterable[int] = None
device: torch.device
dtype: torch.dtype
# Construct packed shapes for input and input_t.
if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorStorage):
# High-precision tensor.
in_shape = NVFP4Quantizer.convert_shape_for_fp4(inp.size())
in_shape_t = NVFP4Quantizer.convert_shape_for_fp4(
NVFP4Quantizer.get_columnwise_shape(inp.size())
)
device = inp.device
dtype = inp.dtype
elif isinstance(inp, NVFP4TensorStorage):
if inp._rowwise_data is not None:
in_shape = inp._rowwise_data.size()
device = inp._rowwise_data.device
if inp._columnwise_data is not None:
in_shape_t = inp._columnwise_data.size()
device = inp._columnwise_data.device
dtype = torch.bfloat16
else:
raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or NVFP4TensorStorage, "
f"found {inp.__class__.__name__})"
)
assert in_shape is not None or in_shape_t is not None, "No data found."
world_size = get_distributed_world_size(process_group)
if out_shape is None:
out_shape = [in_shape[0] * world_size] + list(in_shape[1:])
# For cases where inp has dimensions that cannot be quantized,
# we gather in high precision followed by a cast to NVFP4.
if (
not isinstance(inp, NVFP4TensorStorage)
and quantizer is not None
and not quantizer.is_quantizable(inp)
):
out = torch.empty(
out_shape,
dtype=dtype,
device=device,
memory_format=torch.contiguous_format,
)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
out = quantizer(out)
return out, None
# Cast input tensor to NVFP4 with required data
if not isinstance(inp, NVFP4TensorStorage):
inp = quantizer(inp)
elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
quantizer.columnwise_usage and inp._columnwise_data is None
):
warnings.warn(
"Input and quantizer do not have matching usages. "
"Dequantizing and requantizing to NVFP4."
)
inp = quantizer(inp.dequantize())
# Construct NVFP4 output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
# Coalesce NCCL collectives for gathering data and scale inverses.
with torch.distributed._coalescing_manager(
group=process_group,
device=device,
async_ops=async_op,
) as gather_coalescing_manager:
# Gather NVFP4 data for row-wise usage
if quantizer.rowwise_usage:
# Remove padding from NVFP4 scale-inverses
assert in_shape is not None, "Shape not found."
in_scale_inv = inp._rowwise_scale_inv
out_scale_inv = out._rowwise_scale_inv
flattened_in_shape0 = math.prod(in_shape[:-1])
if in_scale_inv.size(0) != flattened_in_shape0:
in_scale_inv = in_scale_inv[:flattened_in_shape0]
out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
# Launch all-gathers
torch.distributed.all_gather_into_tensor(
out_scale_inv,
in_scale_inv,
group=process_group,
)
torch.distributed.all_gather_into_tensor(
out._rowwise_data,
inp._rowwise_data,
group=process_group,
)
# Transfer amax to output.
out._amax_rowwise = inp._amax_rowwise
# Gather the transposed NVFP4 data along first dimension. Fix format later.
if quantizer.columnwise_usage:
# Remove padding from NVFP4 scale-inverses
# For an all-gather on transposed scale inverses,
# padding must be removed from both dimensions.
in_scale_inv = inp._columnwise_scale_inv
# Note: for in_shape_t, flatten the trailing dimensions (the shape is already transposed).
flattened_in_shape0 = in_shape_t[0]
flattened_in_shape1 = math.prod(in_shape_t[1:])
# Remove dim0 padding
if in_scale_inv.size(0) != flattened_in_shape0:
in_scale_inv = in_scale_inv[:flattened_in_shape0]
# Remove dim1 padding (pack first).
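# flattened_in_shape1 counts packed bytes (two FP4 values per byte), so *2 recovers the
# element count and //16 gives one scale entry per 16-element NVFP4 block.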
unpadded_dim1 = flattened_in_shape1 * 2 // 16
if in_scale_inv.size(1) != unpadded_dim1:
in_scale_inv = in_scale_inv[:, :unpadded_dim1].contiguous()
# Construct tensor to gather transposed scale_inv (interleaved) and launch AG.
out_scale_inv = torch.empty(
[flattened_in_shape0 * world_size] + [in_scale_inv.shape[1]],
dtype=in_scale_inv.dtype,
layout=in_scale_inv.layout,
device=in_scale_inv.device,
)
torch.distributed.all_gather_into_tensor(
out_scale_inv,
in_scale_inv,
group=process_group,
)
# Construct tensor to gather transposed data (interleaved) and launch AG.
out_columnwise_data = torch.empty(
[inp._columnwise_data.shape[0] * world_size] + list(inp._columnwise_data.shape[1:]),
dtype=inp._columnwise_data.dtype,
layout=inp._columnwise_data.layout,
device=inp._columnwise_data.device,
)
torch.distributed.all_gather_into_tensor(
out_columnwise_data,
inp._columnwise_data,
group=process_group,
)
# Transfer amax to output.
out._amax_columnwise = inp._amax_columnwise
handle = gather_coalescing_manager if async_op else None
# Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed.
if async_op and quantizer.columnwise_usage:
handle = _NVFP4AllGatherAsyncHandle(
out, out_columnwise_data, out_scale_inv, world_size, handle
)
elif quantizer.columnwise_usage:
_post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle)
return out, handle
def _all_gather_mxfp8( def _all_gather_mxfp8(
inp: torch.Tensor, inp: torch.Tensor,
process_group: dist_group_type, process_group: dist_group_type,
...@@ -1211,7 +1453,7 @@ def _all_gather_mxfp8( ...@@ -1211,7 +1453,7 @@ def _all_gather_mxfp8(
async_op: bool = False, async_op: bool = False,
quantizer: MXFP8Quantizer, quantizer: MXFP8Quantizer,
out_shape: Optional[list[int]] = None, out_shape: Optional[list[int]] = None,
) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]: ) -> tuple[MXFP8TensorStorage, Optional[torch.distributed.Work]]:
"""All-gather MXFP8 tensor along first dimension.""" """All-gather MXFP8 tensor along first dimension."""
# Input tensor attributes # Input tensor attributes
...@@ -1222,7 +1464,7 @@ def _all_gather_mxfp8( ...@@ -1222,7 +1464,7 @@ def _all_gather_mxfp8(
in_shape = inp.size() in_shape = inp.size()
device = inp.device device = inp.device
dtype = inp.dtype dtype = inp.dtype
elif isinstance(inp, MXFP8TensorBase): elif isinstance(inp, MXFP8TensorStorage):
if inp._rowwise_data is not None: if inp._rowwise_data is not None:
in_shape = inp._rowwise_data.size() in_shape = inp._rowwise_data.size()
device = inp._rowwise_data.device device = inp._rowwise_data.device
...@@ -1234,7 +1476,7 @@ def _all_gather_mxfp8( ...@@ -1234,7 +1476,7 @@ def _all_gather_mxfp8(
dtype = torch.bfloat16 # Guess high-precision dtype. dtype = torch.bfloat16 # Guess high-precision dtype.
else: else:
raise ValueError( raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, " "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorStorage, "
f"found {inp.__class__.__name__})" f"found {inp.__class__.__name__})"
) )
...@@ -1246,7 +1488,7 @@ def _all_gather_mxfp8( ...@@ -1246,7 +1488,7 @@ def _all_gather_mxfp8(
# For cases where inp has dimensions that cannot be quantized, # For cases where inp has dimensions that cannot be quantized,
# we gather in high precision followed by a cast to FP8. # we gather in high precision followed by a cast to FP8.
if ( if (
not isinstance(inp, MXFP8TensorBase) not isinstance(inp, MXFP8TensorStorage)
and quantizer is not None and quantizer is not None
and not quantizer.is_quantizable(inp) and not quantizer.is_quantizable(inp)
): ):
...@@ -1261,7 +1503,7 @@ def _all_gather_mxfp8( ...@@ -1261,7 +1503,7 @@ def _all_gather_mxfp8(
return out, None return out, None
# Cast input tensor to MXFP8 with required data # Cast input tensor to MXFP8 with required data
if not isinstance(inp, MXFP8TensorBase): if not isinstance(inp, MXFP8TensorStorage):
inp = quantizer(inp) inp = quantizer(inp)
elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
quantizer.columnwise_usage and inp._columnwise_data is None quantizer.columnwise_usage and inp._columnwise_data is None
...@@ -1291,7 +1533,6 @@ def _all_gather_mxfp8( ...@@ -1291,7 +1533,6 @@ def _all_gather_mxfp8(
flattened_in_shape0 = math.prod(in_shape[:-1]) flattened_in_shape0 = math.prod(in_shape[:-1])
if in_scale_inv.size(0) != flattened_in_shape0: if in_scale_inv.size(0) != flattened_in_shape0:
in_scale_inv = in_scale_inv[:flattened_in_shape0] in_scale_inv = in_scale_inv[:flattened_in_shape0]
out_scale_inv[flattened_in_shape0 * world_size :].zero_()
out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
# Launch all-gathers # Launch all-gathers
...@@ -1315,7 +1556,6 @@ def _all_gather_mxfp8( ...@@ -1315,7 +1556,6 @@ def _all_gather_mxfp8(
flattened_in_shape0 = math.prod(in_shape[:-1]) // 32 flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
if in_scale_inv.size(0) != flattened_in_shape0: if in_scale_inv.size(0) != flattened_in_shape0:
in_scale_inv = in_scale_inv[:flattened_in_shape0] in_scale_inv = in_scale_inv[:flattened_in_shape0]
out_scale_inv[flattened_in_shape0 * world_size :].zero_()
out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
# Launch all-gathers # Launch all-gathers
...@@ -1347,7 +1587,7 @@ def gather_along_first_dim( ...@@ -1347,7 +1587,7 @@ def gather_along_first_dim(
# Return immediately if no communication is required # Return immediately if no communication is required
world_size = get_distributed_world_size(process_group) world_size = get_distributed_world_size(process_group)
if world_size == 1: if world_size == 1:
if quantizer is not None and not isinstance(inp, QuantizedTensor): if quantizer is not None and not isinstance(inp, QuantizedTensorStorage):
inp = quantizer(inp) inp = quantizer(inp)
return inp, None return inp, None
...@@ -1394,7 +1634,7 @@ def gather_along_first_dim( ...@@ -1394,7 +1634,7 @@ def gather_along_first_dim(
out_shape[0] *= world_size out_shape[0] *= world_size
# FP8 case: delayed scaling or current scaling # FP8 case: delayed scaling or current scaling
if isinstance(inp, Float8TensorBase) or isinstance( if isinstance(inp, Float8TensorStorage) or isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
): ):
return _all_gather_fp8( return _all_gather_fp8(
...@@ -1406,7 +1646,9 @@ def gather_along_first_dim( ...@@ -1406,7 +1646,9 @@ def gather_along_first_dim(
) )
# FP8 block scaling case, block length = 128 # FP8 block scaling case, block length = 128
if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer): if isinstance(inp, Float8BlockwiseQTensorStorage) or isinstance(
quantizer, Float8BlockQuantizer
):
return _all_gather_fp8_blockwise( return _all_gather_fp8_blockwise(
inp, inp,
process_group, process_group,
...@@ -1416,7 +1658,7 @@ def gather_along_first_dim( ...@@ -1416,7 +1658,7 @@ def gather_along_first_dim(
) )
# MXFP8 case # MXFP8 case
if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): if isinstance(inp, MXFP8TensorStorage) or isinstance(quantizer, MXFP8Quantizer):
assert isinstance(quantizer, MXFP8Quantizer) assert isinstance(quantizer, MXFP8Quantizer)
return _all_gather_mxfp8( return _all_gather_mxfp8(
inp, inp,
...@@ -1426,13 +1668,24 @@ def gather_along_first_dim( ...@@ -1426,13 +1668,24 @@ def gather_along_first_dim(
out_shape=out_shape, out_shape=out_shape,
) )
# NVFP4 case
if isinstance(inp, NVFP4TensorStorage) or isinstance(quantizer, NVFP4Quantizer):
assert isinstance(quantizer, NVFP4Quantizer)
return _all_gather_nvfp4(
inp,
process_group,
async_op=async_op,
quantizer=quantizer,
out_shape=out_shape,
)
# High-precision communication for quantized tensors # High-precision communication for quantized tensors
if quantizer is not None: if quantizer is not None:
warnings.warn( warnings.warn(
"Attempting to all-gather an unsupported quantized tensor. " "Attempting to all-gather an unsupported quantized tensor. "
"Falling back to high-precision all-gather." "Falling back to high-precision all-gather."
) )
if isinstance(inp, QuantizedTensor): if isinstance(inp, QuantizedTensorStorage):
inp = inp.dequantize() inp = inp.dequantize()
# Falling back to high-precision all-gather for Float8BlockQuantizer # Falling back to high-precision all-gather for Float8BlockQuantizer
# means that it should directly output GEMM_READY format # means that it should directly output GEMM_READY format
...@@ -1450,7 +1703,7 @@ def gather_along_first_dim( ...@@ -1450,7 +1703,7 @@ def gather_along_first_dim(
return out, None return out, None
# Dequantize quantized tensor if not supported # Dequantize quantized tensor if not supported
if isinstance(inp, QuantizedTensor): if isinstance(inp, QuantizedTensorStorage):
warnings.warn( warnings.warn(
"Attempting to all-gather an unsupported quantized tensor. " "Attempting to all-gather an unsupported quantized tensor. "
"Falling back to high-precision all-gather." "Falling back to high-precision all-gather."
...@@ -1720,7 +1973,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: ...@@ -1720,7 +1973,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
if hasattr(fsdp_root, "primary_weights_in_fp8"): if hasattr(fsdp_root, "primary_weights_in_fp8"):
assert not fsdp_root.primary_weights_in_fp8, ( assert not fsdp_root.primary_weights_in_fp8, (
"TE modules with primary weights in FP8 cannot be FSDP-wrapped. " "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
"Please initialize your model without the te.fp8_model_init(...) context." "Please initialize your model without the te.quantized_model_init(...) context."
) )
root_state = _get_module_fsdp_state(fsdp_root) root_state = _get_module_fsdp_state(fsdp_root)
assert root_state is not None, "Root module does not have a valid _FSDPState." assert root_state is not None, "Root module does not have a valid _FSDPState."
...@@ -1733,7 +1986,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: ...@@ -1733,7 +1986,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
if hasattr(fsdp_module.module, "primary_weights_in_fp8"): if hasattr(fsdp_module.module, "primary_weights_in_fp8"):
assert not fsdp_module.module.primary_weights_in_fp8, ( assert not fsdp_module.module.primary_weights_in_fp8, (
"TE modules with primary weights in FP8 cannot be FSDP-wrapped. " "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
"Please initialize your model without the te.fp8_model_init(...) context." "Please initialize your model without the te.quantized_model_init(...) context."
) )
setattr(fsdp_module.module, "fsdp_group", state.process_group) setattr(fsdp_module.module, "fsdp_group", state.process_group)
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
"""Internal data structures for quantized tensors."""
"""Experimental features and APIs."""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""GEMM API for experimental middleware between Transformer Engine and Kitchen."""
from typing import Iterable, Optional
import torch
from transformer_engine.pytorch.experimental.quantization import (
MMParams,
GEMMType,
)
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer
from transformer_engine.pytorch.tensor.utils import is_experimental
def experimental_gemm(
A: QuantizedTensorStorage,
B: QuantizedTensorStorage,
workspace: torch.Tensor, # pylint: disable=unused-argument
out_dtype: Optional[torch.dtype] = None,
quantization_params: Optional[Quantizer] = None, # pylint: disable=unused-argument
gelu: bool = False, # pylint: disable=unused-argument
gelu_in: torch.Tensor = None, # pylint: disable=unused-argument
accumulate: bool = False, # pylint: disable=unused-argument
layout: str = "TN",
out: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
bias: Optional[torch.Tensor] = None,
use_split_accumulator: bool = False,
grad: bool = False,
) -> Iterable[Optional[torch.Tensor]]:
"""Dispatch GEMM to quantizer's qgemm method."""
assert is_experimental(A) and is_experimental(B), "A and B must be experimental tensors"
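# Swap operands so that, in the branches below, A holds the activation/gradient tensor
# (qx / qdy) and B holds the weight (qw).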
A, B = B, A
# Determine GEMM type based on grad flag and layout
if not grad:
gemm_type = GEMMType.FPROP
else:
if layout == "NN":
gemm_type = GEMMType.DGRAD
elif layout == "NT":
gemm_type = GEMMType.WGRAD
else:
# Default to FPROP for other layouts
gemm_type = GEMMType.FPROP
# Extract quantizer from QuantizedTensor to get qgemm logic
# TODO(etsykunov): make it more flexible, what if we might want to use gemm logic from B._quantizer?
quantizer = None
if hasattr(A, "_quantizer") and A._quantizer is not None:
quantizer = A._quantizer
elif hasattr(B, "_quantizer") and B._quantizer is not None:
quantizer = B._quantizer
else:
raise ValueError("No quantizer found in QuantizedTensor objects")
# Create MMParams
m_params = MMParams(
out_dtype=out_dtype,
use_split_accumulator=use_split_accumulator,
)
out_dtype = A.dtype if m_params.out_dtype is None else m_params.out_dtype
if gemm_type == GEMMType.FPROP:
qx, sx = A.data, A.scale
qw, sw = B.data, B.scale
assert qx is not None
assert sx is not None
assert qw is not None
assert sw is not None
assert A.original_shape is not None
# Call quantizer's qgemm method
result = quantizer.qgemm(
qx,
qw,
m_params,
out_dtype,
sx,
sw,
bias,
gemm_type=GEMMType.FPROP,
qresult_x=A,
qresult_w=B,
)
if len(A.original_shape) > 2:
# Original input was 3D, so we need to reshape result back to 3D
batch_size = A.original_shape[0]
seq_len = A.original_shape[1]
result = result.view(batch_size, seq_len, result.shape[-1])
elif gemm_type == GEMMType.DGRAD:
qdy, sdy = A.data, A.scale
qw_t, sw_t = B.data_t, B.scale_t
assert qdy is not None
assert sdy is not None
assert qw_t is not None
assert sw_t is not None
result = quantizer.qgemm(
qdy,
qw_t,
m_params,
out_dtype,
sdy,
sw_t,
None,
gemm_type=GEMMType.DGRAD,
qresult_x=A,
qresult_w=B,
)
elif gemm_type == GEMMType.WGRAD:
qdy_t, sdy_t = A.data_t, A.scale_t
qx_t, sx_t = B.data_t, B.scale_t
assert qdy_t is not None
assert sdy_t is not None
assert qx_t is not None
assert sx_t is not None
result = quantizer.qgemm(
qdy_t,
qx_t,
m_params,
out_dtype,
sdy_t,
sx_t,
None,
gemm_type=GEMMType.WGRAD,
qresult_x=A,
qresult_w=B,
)
# Return in the same format as general_gemm
return result, None, None, None
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Quantization API for experimental middleware between Transformer Engine and Kitchen."""
from __future__ import annotations
import dataclasses
import enum
import torch
@enum.unique
class GEMMType(enum.Enum):
"""Type of GEMM operation being performed."""
FPROP = "fprop"
DGRAD = "dgrad"
WGRAD = "wgrad"
@dataclasses.dataclass(frozen=True)
class MMParams:
"""Matrix multiplication parameters."""
out_dtype: torch.dtype | None = None
# Use split accumulator for more accurate FP8 GEMM
use_split_accumulator: bool = True
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""NVFP4 recipe reference implementation."""
import dataclasses
from typing import Optional, Tuple, Union
import torch
from transformer_engine.pytorch.experimental import quantization
from transformer_engine.pytorch.experimental import utils
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer
def nvfp4_ref_rht_2d_quantizer_factory(role):
"""
Quantizer factory for NVFP4 recipe reference implementation (RHT and 2D quantization for weights).
Usage with CustomRecipe and fp8_autocast:
custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory)
with fp8_autocast(fp8_recipe=custom_recipe):
output = model(input)
"""
if role == "linear_input":
return NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
)
if role == "linear_weight":
return NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16),
pow_2_scales=False,
with_rht=False,
)
if role == "linear_grad_output":
return NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
)
return None
def cast_to_fp4x2(x):
"""Quantize a tensor to FP4 E2M1 and store in a byte tensor"""
result = torch.zeros_like(x, dtype=torch.uint8)
result[(x >= 0.0) & (x <= 0.25)] = 0
result[(x > 0.25) & (x < 0.75)] = 1
result[(x >= 0.75) & (x <= 1.25)] = 2
result[(x > 1.25) & (x < 1.75)] = 3
result[(x >= 1.75) & (x <= 2.5)] = 4
result[(x > 2.5) & (x < 3.5)] = 5
result[(x >= 3.5) & (x <= 5.0)] = 6
result[x > 5.0] = 7
result[(x >= -0.25) & (x < -0.0)] = 8
result[(x < -0.25) & (x > -0.75)] = 9
result[(x <= -0.75) & (x >= -1.25)] = 10
result[(x < -1.25) & (x > -1.75)] = 11
result[(x <= -1.75) & (x >= -2.5)] = 12
result[(x < -2.5) & (x > -3.5)] = 13
result[(x <= -3.5) & (x >= -5.0)] = 14
result[x < -5.0] = 15
return result[:, ::2] + result[:, 1::2] * 16
def cast_from_fp4x2(x, dq_dtype):
"""Dequantize FP4 E2M1 tensor that has been represented in a byte tensor"""
fp4_values = torch.tensor(
[
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
-0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
],
device=x.device,
dtype=dq_dtype,
)
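# Each byte packs two E2M1 codes: the low nibble holds the even column and the high nibble the
# odd column (matching the packing in cast_to_fp4x2).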
# Convert to long integers for indexing
second_bit = torch.div(x, 16, rounding_mode="floor").to(torch.long)
first_bit = (x - second_bit * 16).to(torch.long)
# Use the long integers to index fp4_values
first_bit_values = fp4_values[first_bit]
second_bit_values = fp4_values[second_bit]
result = torch.zeros(
(first_bit_values.shape[0], first_bit_values.shape[1] * 2),
device=x.device,
dtype=dq_dtype,
)
result[:, ::2] = first_bit_values
result[:, 1::2] = second_bit_values
return result
def cast_to_e8(decode_scale):
"""Cast to a value that is representable in FP8 E8M0.
The result is in FP32, not FP8 E8M0.
"""
max_exponent = torch.tensor(127, device=decode_scale.device, dtype=torch.float32)
exponent = torch.ceil(torch.log2(decode_scale))
exponent = torch.clamp(exponent, min=-max_exponent, max=max_exponent)
return torch.tensor(2.0, device=decode_scale.device, dtype=torch.float32) ** exponent
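# Illustrative example: for decode_scale = 3.0, ceil(log2(3)) = 2, so cast_to_e8 returns 4.0,
# i.e. the scale is rounded up to the next power of two (still stored in FP32).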
def cast_to_e4m3(decode_scale, global_amax):
"""Scale and cast to FP8 E4M3.
decode_scale is actually the encoding scaling factor. global_amax
can be any data tensor and not just the amax.
TODO(etsykunov): Make less unintuitive.
"""
decode_scale = decode_scale * global_amax
FLOAT8_E4M3_MAX = torch.tensor(448.0, device=decode_scale.device, dtype=torch.float32)
decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX)
return decode_scale.to(torch.float8_e4m3fn)
def high_precision_gemm_ref(
a: torch.Tensor,
b: torch.Tensor,
out_dtype: torch.dtype,
accumulate: bool = False,
is_a_transposed: bool = False,
is_b_transposed: bool = False,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
scale_alpha: float = 1.0,
) -> torch.Tensor:
"""GEMM implementation with unquantized data"""
# Handle transpositions
mat1, mat2 = a, b
if is_a_transposed:
mat1 = a.T
if is_b_transposed:
mat2 = b.T
# Ensure dtype compatibility for torch.addmm
mat1 = mat1.to(out_dtype)
mat2 = mat2.to(out_dtype)
# Determine output shape
y_shape = (mat1.size(0), mat2.size(1))
if bias is not None:
assert not accumulate, "Bias is not supported with accumulation"
bias = bias.to(out_dtype)
# With bias case
if out_dtype == torch.float32:
y_ref = torch.addmm(bias.repeat(mat1.size(0), 1), mat1, mat2, beta=1, alpha=1)
else:
y_ref = torch.addmm(bias, mat1, mat2, beta=1, alpha=scale_alpha)
else:
# Without bias case
if accumulate and out is not None:
y_ref = out.clone().to(out_dtype)
else:
y_ref = torch.zeros(y_shape, dtype=out_dtype, device=a.device)
torch.addmm(y_ref, mat1, mat2, beta=1, alpha=scale_alpha, out=y_ref)
return y_ref
@dataclasses.dataclass
class NVFP4TensorRef(QuantizedTensorStorage):
"""NVFP4 tensor for middleware between Transformer Engine and Kitchen.
Custom container to hold quantization result, including quantized tensor, optional
transposed quantized tensor, and corresponding decoding scales.
data: torch.Tensor
the quantized tensor.
scale: torch.Tensor
the decoding scale for the quantized tensor. Shape depends on the scaling granularity.
- if scaling type is PER_TENSOR, it should be a 1D scalar tensor.
data_t: torch.Tensor
the transposed quantized tensor (computed lazily if needed).
scale_t: torch.Tensor
the decoding scale for the transposed quantized tensor.
dtype: torch.dtype
nominal tensor datatype.
device: torch.device
device of the tensor.
quant_dtype: Union[utils.Fp4Formats, torch.dtype]
low precision tensor datatype.
original_shape: Tuple[int, ...]
original shape of the tensor.
_quantizer: Quantizer
Builder class for quantized tensor.
"""
data: Optional[torch.Tensor] = None
scale: Optional[torch.Tensor] = None
data_t: Optional[torch.Tensor] = None
scale_t: Optional[torch.Tensor] = None
global_amax_row: Optional[torch.Tensor] = None
global_amax_col: Optional[torch.Tensor] = None
dtype: Optional[torch.dtype] = None
device: Optional[torch.device] = None
quant_dtype: Optional[Union[utils.Fp4Formats, torch.dtype]] = None
original_shape: Optional[Tuple[int, ...]] = None
_quantizer: Optional[Quantizer] = None
@property
def experimental(self) -> bool:
"""Flag to indicate this quantizer is using experimental Kitchen middleware."""
return True
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the quantization result for saving for backward"""
tensors = [self.data, self.data_t, self.scale, self.scale_t]
self.data = None
self.data_t = None
self.scale = None
self.scale_t = None
return tensors, self
def restore_from_saved(
self, tensors: list[Optional[torch.Tensor]]
) -> list[Optional[torch.Tensor]]:
"""Restore the quantization result from the saved tensors"""
self.data = tensors[0]
self.data_t = tensors[1]
self.scale = tensors[2]
self.scale_t = tensors[3]
return tensors[4:]
# Compatibility
@property
def _data(self):
return self.data
@_data.setter
def _data(self, value):
self.data = value
@property
def _scale_inv(self):
return self.scale
@_scale_inv.setter
def _scale_inv(self, value):
self.scale = value
def __repr__(self):
return (
f"{self.__class__.__name__}("
f"dtype={self.dtype}, "
f"device={self.device}, "
f"quant_dtype={self.quant_dtype}, "
f"original_shape={self.original_shape}"
")"
)
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
columnwise_usage: Optional[bool] = None,
):
"""Generate or remove quantized data based on provided usage."""
has_data = self.data is not None
has_data_transpose = self.data_t is not None
needs_data = has_data
needs_data_transpose = has_data_transpose
if rowwise_usage is not None:
needs_data = rowwise_usage
if columnwise_usage is not None:
needs_data_transpose = columnwise_usage
# Generate data that is required
if needs_data and not has_data:
raise RuntimeError("Cannot generate FP4 data, even from FP4 data transpose")
if needs_data_transpose and not has_data_transpose:
if not has_data:
raise RuntimeError("FP4 data is required to generate FP4 data transpose")
self._create_transpose()
# Delete data that is not required
if not needs_data:
self.data = None
if not needs_data_transpose:
self.data_t = None
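# Illustrative usage (a sketch, not called by the class itself): after the forward GEMM
# only the row-wise data may be needed, so a caller can drop the transpose with
# `tensor.update_usage(columnwise_usage=False)`; requesting the transpose again later
# regenerates it from `data` via `_create_transpose()`.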
def _create_transpose(self):
"""Create transposed quantized tensor"""
if not self.data.is_contiguous():
self.data = self.data.contiguous()
self.data_t = self.data.t().contiguous()
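# Note: this reference container reuses the row-wise decoding scales for the
# transposed data rather than recomputing them for the transposed layout.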
self.scale_t = self.scale
def size(self, *args, **kwargs): # pylint: disable=unused-argument
"""Return the original tensor shape, not the internal packed data shape.
FP4 quantization packs two 4-bit values into each 8-bit value, which reduces
the second dimension by half. This method returns the logical shape that
users expect, not the internal packed storage shape.
"""
assert self.original_shape is not None
return torch.Size(self.original_shape)
def get_wgrad_sign_vector() -> torch.Tensor:
"""Hard-coded signs for Hadamard transform"""
return torch.tensor(
[1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0],
dtype=torch.float32,
)
class NVFP4QuantizerRef(Quantizer):
"""NVFP4 quantizer for middleware between Transformer Engine and Kitchen"""
def __init__(
self,
dtype: utils.Fp4Formats,
rowwise: bool = True,
columnwise: bool = True,
pow_2_scales: bool = False,
eps: float = 0.0,
quant_tile_shape: Tuple[int, int] = (1, 16),
with_rht: bool = False,
with_random_sign_mask: bool = True,
):
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.internal = True
self.dtype = dtype
self.pow_2_scales = pow_2_scales
self.eps = eps
self.quant_tile_shape = quant_tile_shape
self.with_rht = with_rht
self.with_random_sign_mask = with_random_sign_mask
@property
def experimental(self) -> bool:
"""Flag to indicate this quantizer is using experimental Kitchen middleware"""
return True
@staticmethod
def _build_hadamard_matrix(
size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True
) -> torch.Tensor:
"""Construct a Hadamard matrix of given power-of-two size with entries +-1.
Uses Sylvester construction to avoid SciPy dependency.
"""
assert (size & (size - 1)) == 0, "Hadamard size must be a power of two"
h = torch.ones((1, 1), device=device, dtype=torch.float32)
while h.shape[0] < size:
h = torch.cat(
[
torch.cat([h, h], dim=1),
torch.cat([h, -h], dim=1),
],
dim=0,
)
if with_random_sign_mask:
sign_mat = get_wgrad_sign_vector().to(device) * torch.eye(
size, device=device, dtype=torch.float32
)
h = sign_mat @ h
return h.to(dtype)
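# Sylvester doubling sketch: H_1 = [[1]], H_2 = [[1, 1], [1, -1]],
# H_4 = [[H_2, H_2], [H_2, -H_2]], and so on, so H_16 is reached after four doublings.
# When `with_random_sign_mask` is set, multiplying by diag(sign vector) flips the
# corresponding rows of H using the hard-coded signs from get_wgrad_sign_vector().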
def _apply_rht(self, x: torch.Tensor) -> torch.Tensor:
"""Apply randomized Hadamard transform without random signs (reference path).
This matches the reference used in tests: x_reshaped @ (H * (1/sqrt(g))).
"""
# Only apply when enabled
if not self.with_rht:
return x
# RHT dimension equals the quantization tile length (NVFP4 uses 16)
rht_dim = self.quant_tile_shape[1]
assert (
x.shape[-1] % rht_dim == 0
), f"Inner dimension {x.shape[-1]} must be divisible by hadamard dimension {rht_dim}"
# Build H and scale
H = self._build_hadamard_matrix(rht_dim, x.device, x.dtype, self.with_random_sign_mask)
scale = 1.0 / float(rht_dim) ** 0.5
# Perform blockwise transform along the last dimension
original_shape = x.shape
x_mat = x.contiguous().view(-1, rht_dim)
# Any sign flipping is already folded into H by _build_hadamard_matrix
transform = H * scale
out = x_mat @ transform
return out.view(original_shape)
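# Note: diag(signs) @ H / sqrt(g) is orthogonal, so the per-tile transform preserves the
# Euclidean norm of each length-g block; it only redistributes magnitudes, which is the
# usual motivation for applying an RHT before block quantization.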
@staticmethod
def _recover_swizzled_scales(
swizzled_scale: bool, scale: torch.Tensor, m: int, n: int, block_length: int
) -> torch.Tensor:
if not swizzled_scale:
return scale
rounded_m = utils.roundup_div(m, 128) * 128
scale_n = utils.roundup_div(n, block_length)
rounded_n = utils.roundup_div(scale_n, 4) * 4
# Recover swizzled scaling factor layout -> linear layout
tmp = torch.reshape(scale, (rounded_m // 128, rounded_n // 4, 32, 4, 4))
# after permutation, the layout is [rounded_m // 128, 4, 32, rounded_n // 4, 4]
tmp = torch.permute(tmp, (0, 3, 2, 1, 4))
result = torch.reshape(tmp, (rounded_m, rounded_n))
return result[:m, :scale_n]
@classmethod
def _quantize_blockwise_reference(
cls,
x: torch.Tensor,
global_amax: torch.Tensor,
tile_len_x: int,
tile_len_y: int,
*,
pow_2_scales: bool,
eps: float, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.ndim == 2
using_2d_quantization = tile_len_x == 16 and tile_len_y == 16
m, n = x.shape
# Compute vec_max based on the original x (before reshape)
# For 1D quantization: amax over each row chunk of tile_len_x elements
# For 2D quantization: amax over each tile_len_y x tile_len_x block; the result is broadcast
# so every row of a block shares the block amax (e.g. shape (128, 8, 1) for a 128x128 input
# with 16x16 tiles)
if using_2d_quantization:
# x shape: (128, 128)
x_blocks = (
x.unfold(0, tile_len_y, tile_len_y)
.unfold(1, tile_len_x, tile_len_x)
.to(torch.float32)
) # (8, 8, 16, 16)
block_amax = torch.amax(torch.abs(x_blocks), dim=(-1, -2)) # (8, 8)
# Now, expand to (128, 8, 1) by repeating each block_amax for 16 rows
vec_max = block_amax.repeat_interleave(tile_len_y, dim=0).unsqueeze(-1) # (128, 8, 1)
else:
# x shape: (128, 128)
x_reshaped = x.view(m, n // tile_len_x, tile_len_x) # (128, 8, 16)
vec_max = torch.amax(torch.abs(x_reshaped), dim=-1, keepdim=True).to(
torch.float32
) # (128, 8, 1)
x = x.view(m, n // tile_len_x, tile_len_x)
FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32)
FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32)
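# NVFP4 two-level scaling (non pow-2 branch below): a value is decoded as
# q * decode_scale * global_decode_scale, where the per-block decode_scale is stored in
# FP8 E4M3 and global_decode_scale = global_amax / (448 * 6) is a single FP32 factor per
# tensor. encode_scale is the reciprocal used to map inputs into the representable
# FP4 range [-6, 6] before casting.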
decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX)
if pow_2_scales:
decode_scale = cast_to_e8(decode_scale)
encode_scale = torch.div(
torch.tensor(1.0, device=x.device, dtype=torch.float32),
decode_scale.to(torch.float32),
)
else:
global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax)
global_encode_scale = torch.min(
global_encode_scale,
torch.tensor(
torch.finfo(torch.float32).max,
device=global_encode_scale.device,
dtype=torch.float32,
),
)
if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32):
global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32)
global_decode_scale = torch.div(1.0, global_encode_scale)
decode_scale = decode_scale * global_encode_scale
decode_scale = torch.min(
decode_scale,
torch.tensor(
torch.finfo(torch.float32).max,
device=decode_scale.device,
dtype=torch.float32,
),
)
decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX)
decode_scale = decode_scale.to(torch.float8_e4m3fn)
encode_scale = torch.min(
torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale),
torch.tensor(
torch.finfo(torch.float32).max,
device=decode_scale.device,
dtype=torch.float32,
),
)
scaled_x = x.to(torch.float32) * encode_scale
clipped_x = torch.clamp(scaled_x, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX).reshape(m, n)
return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1)
@staticmethod
def _pad_tensor(
tensor: torch.Tensor, row_divisor: Optional[int], col_divisor: Optional[int]
) -> torch.Tensor:
assert tensor.dim() == 2, "only supports 2D tensors"
M, N = tensor.shape
padding_needed_rows = 0
padding_needed_cols = 0
if row_divisor is not None and M % row_divisor != 0:
padding_needed_rows = row_divisor - (M % row_divisor)
# Check and calculate column padding if col_divisor is provided
if col_divisor is not None and N % col_divisor != 0:
padding_needed_cols = col_divisor - (N % col_divisor)
# Return original tensor if no padding is needed
if padding_needed_rows == 0 and padding_needed_cols == 0:
return tensor
# pad the tensor
out = torch.nn.functional.pad(
tensor,
(0, padding_needed_cols, 0, padding_needed_rows),
mode="constant",
value=0.0,
).contiguous()
return out
@staticmethod
def _rm_pad_tensor(tensor: torch.Tensor, original_size: tuple[int, ...]) -> torch.Tensor:
assert tensor.dim() == 2, "only supports 2D tensors"
M, N = original_size
out = tensor[:M, :N].contiguous()
return out
def _quantize(self, tensor: torch.Tensor) -> Tuple[
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
torch.Tensor,
torch.Tensor,
]:
"""
Python implementation of microblock FP4 quantization.
Parameters
----------
tensor : torch.Tensor
Input tensor to quantize (should be 2D)
Returns
-------
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor]
(qx, sx, qx_t, sx_t, global_amax) where:
- qx: quantized data in row-major order (if rowwise_usage), None otherwise
- sx: scale tensor for qx (if rowwise_usage), None otherwise
- qx_t: quantized data in column-major order (if columnwise_usage), None otherwise
- sx_t: scale tensor for qx_t (if columnwise_usage), None otherwise
- global_amax_row, global_amax_col: global amax tensors
"""
if self.pow_2_scales:
assert self.quant_tile_shape == (
1,
32,
), "MXFP4 only supports 1x32 tile shape."
# TODO(etsykunov): Fix bug where global_amax_row and
# global_amax_col are not defined
# global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32)
else:
assert self.quant_tile_shape in (
(1, 16),
(16, 16),
), "NVFP4 only supports 1x16 or 16x16 tile shape."
# Prepare inputs once so we can reuse for both amax and quantization
# Row-input will always be the original input.
row_input = tensor
col_input = (
self._apply_rht(tensor.t().contiguous())
if self.with_rht
else tensor.t().contiguous()
)
# Compute amax for rowwise and columnwise paths separately
global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1)
global_amax_col = (
torch.max(torch.abs(col_input)).to(torch.float32).view(1)
if self.columnwise_usage
else global_amax_row
)
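# When RHT is enabled the transformed column-wise input generally has a different amax
# than the raw input, so the row-wise and column-wise paths track separate global amaxes.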
transpose_scales = False
M, N = tensor.shape
if self.rowwise_usage:
x_input = row_input
x_padded = self._pad_tensor(
x_input, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1]
)
qx, sx = self._quantize_blockwise_reference(
x_padded,
global_amax_row,
self.quant_tile_shape[1],
self.quant_tile_shape[0],
pow_2_scales=self.pow_2_scales,
eps=self.eps,
)
if transpose_scales:
sx = sx.T
qx = self._rm_pad_tensor(qx, (M, N // 2))
else:
qx = None
sx = None
if self.columnwise_usage:
x_t = col_input
x_t_padded = self._pad_tensor(
x_t, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1]
)
qx_t, sx_t = self._quantize_blockwise_reference(
x_t_padded,
global_amax_col,
self.quant_tile_shape[1],
self.quant_tile_shape[0],
pow_2_scales=self.pow_2_scales,
eps=self.eps,
)
qx_t = self._rm_pad_tensor(qx_t, (N, M // 2))
if transpose_scales:
sx_t = sx_t.T
else:
qx_t = None
sx_t = None
return qx, sx, qx_t, sx_t, global_amax_row, global_amax_col
def quantize(
self,
tensor: torch.Tensor,
**kwargs, # pylint: disable=unused-argument
) -> NVFP4TensorRef:
# sanity checks
assert tensor.dtype in utils.HIGH_PRECISION_FLOAT_DTYPES, "Unsupported input dtype."
# Make it work with 3D tensors
original_shape = tensor.shape
if tensor.ndim > 2:
tensor = tensor.view(-1, tensor.shape[-1])
qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(tensor)
return NVFP4TensorRef(
data=qx,
scale=sx,
data_t=qx_t,
scale_t=sx_t,
global_amax_row=global_amax_row,
global_amax_col=global_amax_col,
dtype=tensor.dtype,
device=tensor.device,
quant_dtype=self.dtype,
_quantizer=self,
original_shape=original_shape,
)
def update_quantized(
self,
src: torch.Tensor,
dst: QuantizedTensorStorage,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensorStorage:
"""Update the quantized tensor with the given tensor in-place
Parameters
----------
src: torch.Tensor
Source tensor to copy from
dst: QuantizedTensorStorage
Destination QuantizedTensorStorage to update
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
# Handle noop flag
if noop_flag is not None and noop_flag.item() != 0:
return dst
# Make sure input is in expected format
if not src.is_contiguous():
src = src.contiguous()
# Store the original shape and reshape for processing
original_shape = src.shape
if src.ndim > 2:
src = src.view(-1, src.shape[-1])
qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(src)
# Update the destination with new data
dst.data = qx
dst.scale = sx
dst.data_t = qx_t
dst.scale_t = sx_t
dst.global_amax_row = global_amax_row
dst.global_amax_col = global_amax_col
dst.dtype = src.dtype
dst.quant_dtype = self.dtype
dst.original_shape = original_shape
return dst
@property
def supports_allgather_fp8(self) -> bool:
"""Whether the tensor data can be all-gathered with an FP8 all-gather.
TODO(etsykunov): Confirm docstring is correct. Also, this API
seems too FP8-specific and should be reconsidered.
"""
return False
def transpose_qresult(self, qresult: QuantizedTensorStorage) -> QuantizedTensorStorage:
"""Convert row-wise data to column-wise data (?)
TODO(etsykunov): Confirm docstring is correct.
"""
raise NotImplementedError("Transpose qresult is not implemented for FP4.")
@property
def supports_dequantize(self) -> bool:
"""Whether quantized tensor can converted to high-precision tensor"""
return False
@property
def is_data_t_transposed_in_memory(self) -> bool:
"""Whether column-wise data is stored in transposed layout.
TODO(etsykunov): Confirm docstring is correct.
"""
raise NotImplementedError("Not implemented yet")
def qgemm(
self,
qx: torch.Tensor,
qw: torch.Tensor,
m_params: quantization.MMParams, # pylint: disable=unused-argument
out_dtype: torch.dtype,
sx: torch.Tensor,
sw: torch.Tensor,
bias: torch.Tensor | None = None,
out: torch.Tensor | None = None,
accumulate: bool = False,
gemm_type: quantization.GEMMType = quantization.GEMMType.FPROP,
qresult_x: QuantizedTensorStorage | None = None,
qresult_w: QuantizedTensorStorage | None = None,
) -> torch.Tensor:
"""Python implementation of microblock FP4 GEMM."""
assert bias is None, "Bias is implemented for FP4 GEMM."
high_precision_x = cast_from_fp4x2(qx, out_dtype)
high_precision_w = cast_from_fp4x2(qw, out_dtype)
if self.pow_2_scales:
if sx.dtype == torch.uint8:
# if scaling factor is stored in uint8 container
sx = torch.tensor(2.0, device=sx.device, dtype=torch.float32) ** (
(
sx.to(torch.float32)
- torch.tensor(127, device=sx.device, dtype=torch.float32)
)
)
sw = torch.tensor(2.0, device=sw.device, dtype=torch.float32) ** (
(
sw.to(torch.float32)
- torch.tensor(127, device=sw.device, dtype=torch.float32)
)
)
else:
# if scaling factor is torch.float8_e8m0fnu
sx = sx.to(torch.float32)
sw = sw.to(torch.float32)
alpha = torch.tensor(1.0, device=high_precision_x.device, dtype=torch.float32)
else:
assert qresult_x is not None
assert qresult_w is not None
assert qresult_x.global_amax_row is not None
assert qresult_w.global_amax_col is not None
sx = sx.to(torch.float32)
sw = sw.to(torch.float32)
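# factor = (FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX)^2: both operands were encoded with a global
# encode scale of (448 * 6) / global_amax, so the product of their global decode scales is
# (amax_x * amax_w) / factor, which is applied once per GEMM as alpha below.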
factor = 6.0 * 6.0 * 448.0 * 448.0
if gemm_type == quantization.GEMMType.WGRAD:
partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col
else:
partial_alpha = qresult_x.global_amax_row * qresult_w.global_amax_row
alpha = torch.div(partial_alpha, factor).squeeze(-1)
M, K = high_precision_x.shape
N, K_w = high_precision_w.shape
assert K == K_w, "K dimension mismatch between qx and qw"
assert K % 32 == 0, "K dimension must be divisible by 32"
assert N % 8 == 0, "N dimension must be divisible by 8"
block_length = 32 if self.pow_2_scales else 16
grid_k = K // block_length
assert sx.shape == (
M,
K // block_length,
), f"sx shape mismatch: expected ({M}, {K//block_length}), got {sx.shape}"
assert sw.shape == (
N,
K // block_length,
), f"sw shape mismatch: expected ({N}, {K//block_length}), got {sw.shape}"
y = torch.zeros(M, N, dtype=torch.float32, device=qx.device)
# The implementation below mirrors the FP4 tensor core behavior:
# each output element (i, j) is the fp32 accumulation of (K // block_length) inner products,
# where each inner product is sx * sw * (1, block_length) x (block_length, 1) computed in fp32.
# The computation is then batched over the M and N dimensions.
for k in range(grid_k):
k_start = k * block_length
k_end = k_start + block_length
qx_block = high_precision_x[:, k_start:k_end].clone().contiguous()
qw_block = high_precision_w[:, k_start:k_end].clone().contiguous()
# Extract scaling factors for the current blocks
sx_block = sx[:, k]
sw_block = sw[:, k]
y += torch.outer(sx_block, sw_block) * high_precision_gemm_ref(
qx_block, qw_block, torch.float32, is_b_transposed=True
)
if not self.pow_2_scales and K > 0:
# only apply global scale for NVFP4 and non-empty cases
y = alpha * y
# accumulation happens at epilogue in float32
if accumulate:
assert out is not None, "Output tensor must be provided for accumulation."
y += out.to(torch.float32)
else:
assert out is None, "Output tensor should be None when accumulate is False."
y = y.to(out_dtype)
return y
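# Illustrative usage sketch (hypothetical names and argument values, for reference only):
#
#   quantizer = NVFP4QuantizerRef(dtype=utils.Fp4Formats.E2M1, rowwise=True, columnwise=True)
#   x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
#   w = torch.randn(64, 256, dtype=torch.bfloat16, device="cuda")
#   qx, qw = quantizer.quantize(x), quantizer.quantize(w)
#   y = quantizer.qgemm(qx.data, qw.data, m_params, torch.bfloat16, qx.scale, qw.scale,
#                       qresult_x=qx, qresult_w=qw)  # m_params: a quantization.MMParams instance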
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility functions for experimental middleware between Transformer Engine and Kitchen."""
import enum
import torch
HIGH_PRECISION_FLOAT_DTYPES = (
torch.float,
torch.float16,
torch.bfloat16,
torch.float32,
)
class Fp4Formats(enum.Enum):
"""FP4 data format"""
E2M1 = "e2m1"
def roundup_div(x: int, y: int) -> int:
"""Round up division"""
assert x >= 0
assert y > 0
return (x + y - 1) // y
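# e.g. roundup_div(130, 128) == 2 and roundup_div(128, 128) == 1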
...@@ -4,6 +4,16 @@ ...@@ -4,6 +4,16 @@
"""Tensor class with FP8 data""" """Tensor class with FP8 data"""
import warnings
from .tensor.float8_tensor import Float8Tensor from .tensor.float8_tensor import Float8Tensor
warnings.warn(
"transformer_engine.pytorch.float8_tensor is deprecated and will be removed"
" in a future release. Float8Tensor should be imported directly through "
"`from transformer_engine.pytorch import Float8Tensor`",
DeprecationWarning,
stacklevel=2,
)
__all__ = ["Float8Tensor"] __all__ = ["Float8Tensor"]
...@@ -2,18 +2,26 @@ ...@@ -2,18 +2,26 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""FP8 utilities for TransformerEngine""" """
from __future__ import annotations DEPRECATED in favor of `transformer_engine.pytorch.quantization.py`.
"""
import abc # pylint: disable=wrong-import-position,unused-import
import itertools
import os
from contextlib import contextmanager
from collections import deque
from typing import Callable, List, Optional, Dict, Any, Tuple, Union
import torch import warnings
import transformer_engine_torch as tex
warnings.warn(
"Using deprecated internal API from Transformer Engine. "
"transformer_engine.pytorch.fp8 will be removed in a "
"future release.",
DeprecationWarning,
stacklevel=2,
)
# There are some users indirectly importing these classes
# from fp8.py. This ensures backwards compatibility.
# https://github.com/Lightning-AI/lightning-thunder/pull/2635.
from transformer_engine.common.recipe import ( from transformer_engine.common.recipe import (
Recipe, Recipe,
DelayedScaling, DelayedScaling,
...@@ -21,1082 +29,43 @@ from transformer_engine.common.recipe import ( ...@@ -21,1082 +29,43 @@ from transformer_engine.common.recipe import (
MXFP8BlockScaling, MXFP8BlockScaling,
Float8CurrentScaling, Float8CurrentScaling,
Float8BlockScaling, Float8BlockScaling,
NVFP4BlockScaling,
CustomRecipe,
) )
from .constants import dist_group_type # Importing each function instead of 'import *' allows us to specify '__all__' in
from .utils import get_device_compute_capability # quantize.py and also makes any newer additions to quantize.py invisible via
from .jit import jit_fuser # fp8.py so that we don't reinforce importing internal TE functions.
from torch.utils.cpp_extension import IS_HIP_EXTENSION from .quantization import (
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0"))) check_fp8_support,
int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0"))) check_mxfp8_support,
blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128")) check_nvfp4_support,
check_fp8_block_scaling_support,
__all__ = ["fp8_autocast", "fp8_model_init"] check_recipe_support,
get_default_fp8_recipe,
if IS_HIP_EXTENSION: get_fp8_torch_dtype,
from transformer_engine.pytorch.utils import is_K100_AI, is_BW get_fp8_te_dtype,
get_fp4_te_dtype,
def check_fp8_support() -> Tuple[bool, str]: get_fp8_max,
"""Return if fp8 support is available""" FP8GlobalStateManager,
if IS_HIP_EXTENSION: fp8_model_init,
if (is_K100_AI() or is_BW()) and int8_simulation_fp8: fp8_autocast,
return True, "DCU turn on fp8 simulation with int8" _update_amax_history,
else: _default_get_amax_and_update_history,
return False, "DCU not support fp8 for now" _default_sf_compute,
else: _compute_amax_and_update_history,
if get_device_compute_capability() >= (9, 0): # hopper and above _compute_scaling_factor,
return True, "" _amax_and_scale_update,
if get_device_compute_capability() < (8, 9): # pre-ada split_and_copy,
return False, "Device compute capability 8.9 or higher required for FP8 execution." RecipeState,
if tex.get_cublasLt_version() < 120103: DelayedScalingRecipeState,
return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." Float8CurrentScalingRecipeState,
if float(torch.version.cuda) < 12.1: MXFP8BlockScalingRecipeState,
return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." Float8BlockScalingRecipeState,
return True, "" NVFP4BlockScalingRecipeState,
CustomRecipeState,
int8_simulation_fp8,
def check_mxfp8_support() -> Tuple[bool, str]: int8_simulation_fp8_tensorwise,
"""Return if fp8 support is available""" blockwise_fp8_block_len
if get_device_compute_capability() >= (12, 0): )
return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
if get_device_compute_capability() >= (10, 0): # blackwell and above
return True, ""
return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
"""Return if fp8 block scaling support is available"""
if IS_HIP_EXTENSION:
if is_K100_AI() or is_BW():
return True, ""
else:
return False, "DCU not support block_scaling fp8 for now"
if (
get_device_compute_capability() >= (9, 0)
and get_device_compute_capability() < (10, 0)
and float(torch.version.cuda) >= 12.9
):
return True, ""
return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."
def check_recipe_support(recipe: Recipe) -> None:
"""Check if the given recipe is supported."""
recipe_supported = True
unsupported_reason = ""
if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)):
recipe_supported, unsupported_reason = check_fp8_support()
elif isinstance(recipe, Float8BlockScaling):
recipe_supported, unsupported_reason = check_fp8_block_scaling_support()
elif isinstance(recipe, MXFP8BlockScaling):
recipe_supported, unsupported_reason = check_mxfp8_support()
assert recipe_supported, unsupported_reason
def get_default_fp8_recipe() -> Recipe:
"""FP8 recipe with default args."""
if check_mxfp8_support()[0]:
return MXFP8BlockScaling()
if get_device_compute_capability() >= (12, 0):
# This is a temporary restriction until MXFP8 is supported for all gemm layouts.
return Float8CurrentScaling()
return DelayedScaling()
def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return torch.float8_e4m3fn
return torch.float8_e5m2
def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return tex.DType.kFloat8E4M3
return tex.DType.kFloat8E5M2
def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
"""Get max representible FP8 value."""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return Format.E4M3.value.max_fwd
return Format.E5M2.value.max_fwd
class FP8GlobalStateManager:
"""Class to keep track of and manipulate the global
FP8 state at different stages of execution.
"""
FP8_ENABLED = False
FP8_CALIBRATION = False
FP8_RECIPE = None
FP8_DISTRIBUTED_GROUP = None
FP8_PARAMETERS = False
HIGH_PRECISION_INIT_VAL = False
IS_FIRST_FP8_MODULE = False
FP8_GRAPH_CAPTURING = False
FP8_AUTOCAST_DEPTH = 0
global_amax_buffer = {}
global_amax_history_buffer = {}
global_scale_buffer = {}
fp8_tensors_recompute_buffer = []
fp8_available = None
reason_for_no_fp8 = ""
autocast_arguments = {}
autocast_to_fp8_params = {}
fp8_param_to_autocast = {}
skip_fp8_weight_update_tensor = None
mxfp8_available = None
reason_for_no_mxfp8 = ""
fp8_block_scaling_available = None
reason_for_no_fp8_block_scaling = None
@classmethod
def reset(cls) -> None:
"""Reset the global state"""
cls.FP8_ENABLED = False
cls.FP8_CALIBRATION = False
cls.FP8_RECIPE = None
cls.FP8_DISTRIBUTED_GROUP = None
cls.FP8_PARAMETERS = False
cls.HIGH_PRECISION_INIT_VAL = False
cls.IS_FIRST_FP8_MODULE = False
cls.FP8_GRAPH_CAPTURING = False
cls.FP8_AUTOCAST_DEPTH = 0
cls.global_amax_buffer = {}
cls.global_amax_history_buffer = {}
cls.global_scale_buffer = {}
cls.fp8_tensors_recompute_buffer = []
cls.fp8_available = None
cls.reason_for_no_fp8 = ""
cls.autocast_arguments = {}
cls.autocast_to_fp8_params = {}
cls.fp8_param_to_autocast = {}
cls.skip_fp8_weight_update_tensor = None
cls.mxfp8_available = None
cls.reason_for_no_mxfp8 = ""
cls.fp8_block_scaling_available = None
cls.reason_for_no_fp8_block_scaling = ""
@classmethod
def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
"""`skip_fp8_weight_update_tensor` inplace setter."""
if cls.skip_fp8_weight_update_tensor is None:
cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
cls.skip_fp8_weight_update_tensor.fill_(skip)
@classmethod
def get_skip_fp8_weight_update_tensor(cls) -> None:
"""`skip_fp8_weight_update_tensor` getter."""
return cls.skip_fp8_weight_update_tensor
@classmethod
def is_fp8_available(cls) -> Tuple[bool, str]:
"""Return if fp8 support is available"""
if cls.fp8_available is None:
cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support()
return cls.fp8_available, cls.reason_for_no_fp8
@classmethod
def is_mxfp8_available(cls) -> Tuple[bool, str]:
"""Return if MXFP8/current scaling support is available."""
if cls.mxfp8_available is None:
cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support()
return cls.mxfp8_available, cls.reason_for_no_mxfp8
@classmethod
def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]:
"""Return if Float8 block scaling support is available."""
if cls.fp8_block_scaling_available is None:
cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = (
check_fp8_block_scaling_support()
)
return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling
@staticmethod
def get_meta_tensor_key(forward: bool = True) -> str:
"""Returns scaling key in `fp8_meta`."""
if forward:
return "scaling_fwd"
return "scaling_bwd"
@staticmethod
def get_fwd_bwd_key(forward: bool = True) -> str:
"""Convert bool `forward` to string."""
return "forward" if forward else "backward"
@classmethod
def get_buffer_info(cls) -> str:
"""
Returns a key for `fp8_meta` that stores the module's index
in the global buffers along with autocast information.
"""
return "buffer_index_and_autocast_key"
@classmethod
def get_key_in_buffer(
cls,
forward: bool,
fp8_recipe: Recipe,
fp8_group: dist_group_type,
) -> str:
"""Returns a key into the global FP8 buffers."""
autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
fwd_bwd_key = cls.get_fwd_bwd_key(forward)
return f"{fwd_bwd_key}_{autocast_key}"
@classmethod
def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]:
"""Splits buffer key into relevant parts."""
forward, autocast_key = key.split("_", 1)
forward = forward == "forward"
return forward, autocast_key
@classmethod
def add_fp8_tensors_to_global_buffer(
cls,
fp8_meta: Dict[str, Any],
) -> None:
"""
Delayed scaling only.
The amax reduction process happens completely outside the FP8 modules.
To participate in the reduction, the only role played by a module is
to call this function in order to append its FP8 tensor into a global
buffer. Three global buffers are maintained, one each for amax, amax
history, and scale. Each buffer has
keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix
to indicate the type of FP8 tensor, since the forward and backward
reductions happen separately.
Note: For CG capture, this method is called from the graphed
wrapper. For non CG case, it's called from within the module.
"""
# delayed scaling only function, noop for any other recipe
if not fp8_meta["recipe"].delayed():
return
# Every module must call this function exactly once since
# the amax tensors are static. Ensures that compatibility
# with non-graphed modules is maintained.
index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors.
if index_in_buffer in fp8_meta:
return
fp8_meta[index_in_buffer] = []
for forward in (True, False):
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
if fp8_meta_tensor_key not in fp8_meta:
# Handles non-parameter FP8 modules, e.g. DPA.
continue
key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"])
if key not in cls.global_amax_buffer:
cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history]
cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale]
else:
cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0])
cls.global_amax_history_buffer[key].append(
fp8_meta[fp8_meta_tensor_key].amax_history
)
cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale)
fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1)
fp8_meta[index_in_buffer].append(key)
@classmethod
def is_fp8_enabled(cls) -> bool:
"""Is FP8 enabled"""
return cls.FP8_ENABLED
@classmethod
def is_fp8_calibration(cls) -> bool:
"""Is FP8 calibration"""
return cls.FP8_CALIBRATION
@classmethod
def with_fp8_parameters(cls) -> bool:
"""Should the parameters be stored as FP8"""
return cls.FP8_PARAMETERS
@classmethod
def with_high_precision_init_val(cls) -> bool:
"""Should the high precision initial values be stored with FP8 parameters"""
return cls.HIGH_PRECISION_INIT_VAL
@classmethod
def fp8_graph_capturing(cls) -> bool:
"""Is CUDA graph capture under way?"""
return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing()
@classmethod
def is_first_fp8_module(cls):
"""Returns `True` only the first time when called multiple
times from within the same `fp8_autocast` context.
"""
tmp = cls.IS_FIRST_FP8_MODULE
cls.IS_FIRST_FP8_MODULE = False
return tmp
@classmethod
def get_fp8_recipe(cls) -> Recipe:
"""Return the fp8 recipe"""
if cls.FP8_RECIPE is not None:
return cls.FP8_RECIPE
return get_default_fp8_recipe()
@classmethod
def get_fp8_group(cls) -> Union[dist_group_type, None]:
"""Return the fp8 group for scale/amax comm"""
return cls.FP8_DISTRIBUTED_GROUP
@classmethod
def get_fp8_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]:
"""FP8 autocast state getter"""
return (
cls.FP8_ENABLED,
cls.FP8_CALIBRATION,
cls.FP8_RECIPE,
cls.FP8_DISTRIBUTED_GROUP,
cls.IS_FIRST_FP8_MODULE,
cls.FP8_GRAPH_CAPTURING,
)
@classmethod
def set_fp8_autocast_state(
cls, fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool]
) -> None:
"""FP8 autocast state setter"""
(
cls.FP8_ENABLED,
cls.FP8_CALIBRATION,
cls.FP8_RECIPE,
cls.FP8_DISTRIBUTED_GROUP,
cls.IS_FIRST_FP8_MODULE,
cls.FP8_GRAPH_CAPTURING,
) = fp8_state
@staticmethod
def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None:
"""Reduce tensor across given group."""
if torch.distributed.is_initialized():
torch.distributed.all_reduce(
tensor,
op=torch.distributed.ReduceOp.MAX,
group=group,
async_op=False,
)
@classmethod
def reduce_and_update_fp8_tensors(
cls,
forward: bool = True,
) -> None:
"""Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer."""
# global_amax_buffer should only be non-empty for fp8 delayed scaling
for buffer_key, amax_buffer in cls.global_amax_buffer.items():
# Check for forward or backward reduction.
fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
if fwd_update != forward:
continue
if len(amax_buffer) == 0:
continue
# Retrieve autocast specific args and concat amaxes.
recipe, group = cls.autocast_arguments[autocast_key]
contiguous_amax = torch.cat(amax_buffer)
# Reduction.
if (
recipe.reduce_amax
and torch.distributed.is_initialized()
and torch.distributed.get_world_size(group=group) > 1
):
cls.reduce_tensor_across_group_op_max(contiguous_amax, group)
# Amax and scale update.
unfused_update = (
bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0")))
or callable(recipe.amax_compute_algo)
or callable(recipe.scaling_factor_compute_algo)
)
if not unfused_update:
tex.fused_amax_and_scale_update_after_reduction(
contiguous_amax,
cls.global_amax_history_buffer[buffer_key],
cls.global_scale_buffer[buffer_key],
recipe.amax_compute_algo,
get_fp8_te_dtype(recipe, forward),
recipe.margin,
)
else:
split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer])
for amax_history, scale in zip(
cls.global_amax_history_buffer[buffer_key],
cls.global_scale_buffer[buffer_key],
):
_amax_and_scale_update(
amax_history, scale, get_fp8_max(recipe, forward), recipe
)
@classmethod
def get_unique_autocast_key(
cls,
recipe: Optional[Recipe] = None,
group: Optional[dist_group_type] = None,
):
"""
For FP8, each autocast can be uniquely identified by the recipe and fp8 group.
Safely using `hash` as we never cross checkpoint boundaries.
"""
return f"{str(recipe)}:{hash(group)}"
@classmethod
def fp8_autocast_enter(
cls,
enabled: bool = False,
calibrating: bool = False,
fp8_recipe: Optional[Recipe] = None,
fp8_group: Optional[dist_group_type] = None,
_graph: bool = False,
) -> None:
"""Set state and tracking variables for entry into FP8 region."""
fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)
cls.FP8_ENABLED = enabled
cls.FP8_CALIBRATION = calibrating
cls.FP8_RECIPE = fp8_recipe
cls.FP8_DISTRIBUTED_GROUP = fp8_group
cls.FP8_GRAPH_CAPTURING = _graph
if cls.FP8_AUTOCAST_DEPTH == 0:
cls.IS_FIRST_FP8_MODULE = True
cls.FP8_AUTOCAST_DEPTH += 1
if enabled:
fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
assert fp8_available, reason_for_no_fp8
if isinstance(fp8_recipe, MXFP8BlockScaling):
mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
assert mxfp8_available, reason_for_no_mxfp8
if isinstance(fp8_recipe, Float8BlockScaling):
fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available()
assert fp8_block_available, reason_for_no_fp8_block
@classmethod
def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
"""Set state and tracking variables for exit from FP8 region."""
cls.FP8_AUTOCAST_DEPTH -= 1
# Reduce only the non-FP8 weight modules here.
# FP8 weight modules are reduced at the end of the optimizer
# step after the weight amax is populated.
if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
# Delayed scaling only. For any other recipe (e.g. current scaling at any granularity)
# this is a no-op because cls.global_amax_buffer is an empty list.
cls.reduce_and_update_fp8_tensors(forward=True)
@classmethod
def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
"""Copy the scaling factors and amaxes for recompute forward phase
to ensure both forward steps are numerically identical.
"""
# delayed scaling only function, noop for any other recipe
if not fp8_meta["recipe"].delayed():
return
buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
to_copy = [
fp8_meta["scaling_fwd"].amax_history.clone(),
fp8_meta["scaling_fwd"].scale.clone(),
]
if buffer_position_key in fp8_meta:
cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
else:
if len(cls.fp8_tensors_recompute_buffer) == 0:
cls.fp8_tensors_recompute_buffer = [deque()]
else:
cls.fp8_tensors_recompute_buffer.append(deque())
cls.fp8_tensors_recompute_buffer[-1].append(to_copy)
fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1
@classmethod
def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
"""Switch to the copied scaling factors and amaxes from phase
1 forward for identical numerical outputs.
"""
# delayed scaling only function, noop for any other recipe
if not fp8_meta["recipe"].delayed():
return
# Store updated amaxes and scales from phase 1 post forward.
fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone()
fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone()
# Retrieve stashed amaxes and scales from phase 1 pre forward.
buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft()
# Replace amaxes and scales with stashed values for phase 2 forward
fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0])
fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1])
@staticmethod
def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
"""Restore latest scaling factors and amaxes after recompute forward run."""
# delayed scaling only function, noop for any other recipe
if not fp8_meta["recipe"].delayed():
return
fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"])
@contextmanager
def fp8_model_init(
enabled: bool = True,
recipe: Optional[Recipe] = None,
preserve_high_precision_init_val: bool = False,
) -> None:
"""
Context manager for FP8 initialization of parameters.
Example usage:
.. code-block:: python
with fp8_model_init(enabled=True):
model = transformer_engine.pytorch.Linear(768, 768)
# Preserving high precision initial value to initialize master weight
with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
model = transformer_engine.pytorch.Linear(768, 768)
master_weight = model.weight.get_high_precision_init_val()
model.weight.clear_high_precision_init_val()
Parameters
----------
enabled: bool, default = `True`
when enabled, Transformer Engine modules created inside this `fp8_model_init`
region will hold only FP8 copies of its parameters, as opposed to the default
behavior where both higher precision and FP8 copies are present. Setting this
option to `True` may result in lower memory consumption and is especially
useful for scenarios like:
* full model training using optimizer with master weights, where the high
precision copies of weights are already present in the optimizer.
* inference, where only the FP8 copies of the parameters are used.
* LoRA-like fine-tuning, where the main parameters of the model do not change.
recipe: transformer_engine.common.recipe.Recipe, default = `None`
Recipe used to create the parameters. If left to None, it uses the default FP8 recipe.
preserve_high_precision_init_val: bool, default = `False`
when enabled, store the high precision tensor used to initialize FP8 parameters
in CPU memory, and add two function attributes named `get_high_precision_init_val()`
and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high
precision tensor. The purpose is that users can use this high-precision copy
to initialize master weights, avoiding the loss of precision that can occur when
using FP8 parameters directly. Note that after the master weights are initialized,
users should call `clear_high_precision_init_val()` to release this CPU memory.
This functionality is *EXPERIMENTAL*.
"""
_fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
_fp8_recipe = FP8GlobalStateManager.FP8_RECIPE
_high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL
FP8GlobalStateManager.FP8_PARAMETERS = enabled
FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe
FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val
try:
yield
finally:
FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe
FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val
@contextmanager
def fp8_autocast(
enabled: bool = True,
calibrating: bool = False,
fp8_recipe: Optional[Recipe] = None,
fp8_group: Optional[dist_group_type] = None,
_graph: bool = False,
) -> None:
"""
Context manager for FP8 usage.
.. code-block:: python
with fp8_autocast(enabled=True):
out = model(inp)
.. note::
Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors
with shapes where both dimensions are divisible by 16. In terms of the input to the full
Transformer network, this typically requires padding the sequence length to be a multiple of 16.
.. note::
When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once
inside a single `fp8_autocast` region. This is unsupported behavior because the amax
reduction is handled during the exit of the `fp8_autocast` context. Calling the same
module more than once inside an `fp8_autocast` region overrides the amax tensors
before reduction can occur.
Parameters
----------
enabled: bool, default = `True`
whether or not to enable fp8
calibrating: bool, default = `False`
calibration mode allows collecting statistics such as amax and scale
data of fp8 tensors even when executing without fp8 enabled. This is
useful for saving an inference ready fp8 checkpoint while training
using a higher precision.
fp8_recipe: recipe.Recipe, default = `None`
recipe used for FP8 training.
fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
distributed group over which amaxes for the fp8 tensors
are reduced at the end of each training step.
"""
if enabled:
check_recipe_support(fp8_recipe)
fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
FP8GlobalStateManager.fp8_autocast_enter(
enabled=enabled,
calibrating=calibrating,
fp8_recipe=fp8_recipe,
fp8_group=fp8_group,
_graph=_graph,
)
try:
yield
finally:
FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)
def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
"""Update amax history and set next amax to zero."""
if amax_history.shape[0] > 1:
new_amax_history = torch.roll(amax_history, -1, 0)
amax_history.copy_(new_amax_history)
amax_history[0].fill_(0.0)
return amax_history
@torch.jit.script
def _default_get_amax_and_update_history(
amax_history: torch.Tensor,
amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Default function to obtain amax from history."""
if amax_compute_algo == "max":
amax = torch.max(amax_history, dim=0).values
else: # amax_compute_algo == "most_recent"
amax = amax_history[0].clone()
amax_history = _update_amax_history(amax_history)
return amax_history, amax
@jit_fuser
def _default_sf_compute(
amax: torch.Tensor,
scale: torch.Tensor,
fp8_max: float,
margin: int,
_fp32_max: float = torch.finfo(torch.float32).max, # torch.finfo is not available inside the jit fuser
) -> torch.Tensor:
"""Default function to convert amax to scaling factor.
Computing the scaling factor requires consideration of the following scenarios:
1. amax == 0:
No action is possible, set scale to the previous scale (or 1).
2. 0 < amax < tiny_amax
The amax is so tiny that the scale becomes infinite in FP32.
Set scale = FP32_max
3. tiny_amax <= amax < FP32_max:
Set scale = FP8_max (or scaled_max) / amax
4. When amax == inf or amax == nan:
No action is possible, set scale to the previous scale (or 1).
"""
sf = (fp8_max / amax) / (2**margin)
sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale)
sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf)
scale.copy_(sf)
return scale
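# Worked example (illustrative): with fp8_max = 448, margin = 0 and amax = 2.0 the new scale
# is 448 / 2.0 = 224, so a value x is stored as x * 224 in FP8 and recovered as q / 224.
# If amax is 0, inf, or NaN the previous scale is kept.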
def _compute_amax_and_update_history(
amax_history: torch.Tensor,
amax_compute_algo: Union[Callable, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Obtain the amax from the history."""
if callable(amax_compute_algo):
amax = amax_compute_algo(amax_history)
amax_history = _update_amax_history(amax_history)
return amax_history, amax
return _default_get_amax_and_update_history(
amax_history,
amax_compute_algo,
)
def _compute_scaling_factor(
amax: torch.Tensor,
scale: torch.Tensor,
fp8_max: float,
recipe: DelayedScaling,
) -> torch.Tensor:
"""Convert amax to scaling factor."""
if recipe.scaling_factor_compute_algo is None:
return _default_sf_compute(
amax,
scale,
fp8_max,
recipe.margin,
)
return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe)
def _amax_and_scale_update(
amax_history: torch.Tensor,
scale: torch.Tensor,
fp8_max: float,
recipe: DelayedScaling,
) -> None:
"""Updates FP8 meta tensors."""
new_amax_history, amax = _compute_amax_and_update_history(
amax_history,
recipe.amax_compute_algo,
)
new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe)
scale.copy_(new_scale)
amax_history.copy_(new_amax_history)
def split_and_copy(
buffer: torch.Tensor,
outputs: List[torch.Tensor],
chunk_sizes: List[int],
) -> None:
"""Split `buffer` by `chunk_sizes` and copy into `outputs`."""
splits = buffer.split(chunk_sizes)
torch._foreach_copy_(outputs, splits)
class RecipeState(abc.ABC):
"""Configuration and state for a quantization recipe.
This is a builder class for quantizers, which are in turn builder
classes for quantized tensors.
This class may pack together the state for multiple quantizers,
which is helpful for applying fused kernels with less overhead.
"""
@staticmethod
def create(
recipe: Recipe,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> RecipeState:
"""Factory method to create the state for a quantization recipe
Parameters
----------
recipe: Recipe
Quantization recipe.
mode: {"forward", "backward"}
Training stage where quantization will be performed.
num_quantizers: int, default = 1
Number of quantizers to create state for.
device: torch.device, default = default CUDA device
Device for quantized tensors.
Returns
-------
RecipeState:
Quantization recipe state.
"""
cls = None
if recipe.delayed():
cls = DelayedScalingRecipeState
elif recipe.mxfp8():
cls = MXFP8BlockScalingRecipeState
elif recipe.float8_current_scaling():
cls = Float8CurrentScalingRecipeState
elif recipe.float8_block_scaling():
cls = Float8BlockScalingRecipeState
else:
raise ValueError(f"{recipe.__class__.__name__} is not supported")
return cls(
recipe,
mode=mode,
num_quantizers=num_quantizers,
device=device,
)
@abc.abstractmethod
def make_quantizers(self) -> list:
"""Convert recipe state to quantizers.
Quantizers are builder classes for quantized tensors. They are
typically used to convert a high-precision tensor (e.g. in
FP32 or BF16) into a quantized tensor (e.g. in FP8).
"""
class DelayedScalingRecipeState(RecipeState):
"""State for FP8 quantization with per-tensor delayed scaling.
Delayed scaling recipe requires a scaling factor (applied when
casting to FP8) and a history of max-abs values ("amax") from
recent FP8 casts for updating the scaling factor. The scale update
is handled externally by `FP8GlobalStateManager`.
"""
recipe: DelayedScaling
mode: str
dtype: tex.DType
scale: torch.Tensor
amax_history: torch.Tensor
def __init__(
self,
recipe: DelayedScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.dtype = get_fp8_te_dtype(recipe, mode == "forward")
# Allocate buffers
if device is None:
device = torch.device("cuda")
self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device)
self.amax_history = torch.zeros(
recipe.amax_history_len,
num_quantizers,
dtype=torch.float32,
device=device,
)
def make_quantizers(self) -> list:
# TODO(ksivamani): Find a better design for this; adding here to avoid a circular import.
from .tensor.float8_tensor import Float8Quantizer
return [
Float8Quantizer(self.scale[i], self.amax_history[0][i].reshape((1,)), self.dtype)
for i in range(self.num_quantizers)
]
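# Illustrative sketch of the intended flow (names as defined above):
#
#   state = RecipeState.create(DelayedScaling(), mode="forward", num_quantizers=3)
#   x_quantizer, w_quantizer, out_quantizer = state.make_quantizers()
#
# Each Float8Quantizer shares storage with one slot of the packed `scale` and `amax_history`
# buffers, which is what lets the fused amax/scale update operate on contiguous tensors.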
class Float8CurrentScalingRecipeState(RecipeState):
"""Configuration for Per-tensor current scaling quantization.
Per-tensor current quantization does not require state.
"""
recipe: Float8CurrentScaling
mode: str
dtype: tex.DType
device: torch.device
def __init__(
self,
recipe: Float8CurrentScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.dtype = get_fp8_te_dtype(recipe, mode == "forward")
# Allocate buffers
if device is None:
device = torch.device("cuda")
self.device = device
def make_quantizers(self) -> list:
from .tensor.float8_tensor import Float8CurrentScalingQuantizer
return [
Float8CurrentScalingQuantizer(self.dtype, device=self.device)
for i in range(self.num_quantizers)
]
class MXFP8BlockScalingRecipeState(RecipeState):
"""Configuration for MXFP8 quantization.
MXFP8 quantization does not require state.
"""
recipe: MXFP8BlockScaling
mode: str
dtype: tex.DType
def __init__(
self,
recipe: MXFP8BlockScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.dtype = get_fp8_te_dtype(recipe, mode == "forward")
# Allocate buffers
if device is None:
device = torch.device("cuda")
def make_quantizers(self) -> list:
# TODO(ksivamani): Find a better design for this; adding here to avoid a circular import.
from .tensor.mxfp8_tensor import MXFP8Quantizer
return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)]
class Float8BlockScalingRecipeState(RecipeState):
"""Configuration for Float8BlockScaling quantization.
Float8BlockScaling quantization does not require state,
but different quantizers use different modes.
"""
recipe: Float8BlockScaling
mode: str
qx_dtype: tex.DType
qw_dtype: tex.DType
qgrad_dtype: tex.DType
def __init__(
self,
recipe: Float8BlockScaling,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
self.qx_dtype = get_fp8_te_dtype(recipe, True)
self.qw_dtype = get_fp8_te_dtype(recipe, True)
self.qgrad_dtype = get_fp8_te_dtype(recipe, False)
# Allocate buffers
if device is None:
device = torch.device("cuda")
self.device = device
def make_quantizers(self) -> list:
# TODO(ksivamani): Find a better design for this; adding here to avoid a circular import.
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
if self.mode == "forward":
# The index convention (coming from base.py set_meta_tensor)
# is somewhat awkward, and doesn't play nicely with QuantizeOp,
# which is not associated with a GEMM.
assert self.num_quantizers % 3 == 0 # x, w, output per gemm
return list(
itertools.chain.from_iterable(
[
[
Float8BlockQuantizer(
fp8_dtype=self.qx_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
block_scaling_dim=self.recipe.x_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qw_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale,
block_scaling_dim=self.recipe.w_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qx_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
block_scaling_dim=self.recipe.x_block_scaling_dim,
),
]
for _ in range(self.num_quantizers // 3)
]
)
)
assert self.mode == "backward", f"Unexpected mode {self.mode}"
assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm
return list(
itertools.chain.from_iterable(
[
[
Float8BlockQuantizer(
fp8_dtype=self.qgrad_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
block_scaling_dim=self.recipe.grad_block_scaling_dim,
),
Float8BlockQuantizer(
fp8_dtype=self.qgrad_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
block_scaling_dim=self.recipe.grad_block_scaling_dim,
),
]
for _ in range(self.num_quantizers // 2)
]
)
)
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
from collections.abc import Iterable from collections.abc import Iterable
import contextlib import contextlib
import gc import gc
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import torch import torch
...@@ -15,8 +16,8 @@ from torch._C import _graph_pool_handle ...@@ -15,8 +16,8 @@ from torch._C import _graph_pool_handle
from transformer_engine.common.recipe import DelayedScaling, Recipe from transformer_engine.common.recipe import DelayedScaling, Recipe
from transformer_engine.pytorch.constants import dist_group_type from transformer_engine.pytorch.constants import dist_group_type
from .fp8 import ( from .quantization import (
fp8_autocast, autocast,
FP8GlobalStateManager, FP8GlobalStateManager,
get_default_fp8_recipe, get_default_fp8_recipe,
) )
...@@ -84,7 +85,7 @@ def _make_graphed_callables( ...@@ -84,7 +85,7 @@ def _make_graphed_callables(
sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]], sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
num_warmup_iters: int = 3, num_warmup_iters: int = 3,
allow_unused_input: bool = False, allow_unused_input: bool = False,
fp8_weight_caching: bool = False, cache_quantized_params: bool = False,
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
_order: Optional[List[int]] = None, _order: Optional[List[int]] = None,
_num_layers_per_chunk: Optional[List[int]] = None, _num_layers_per_chunk: Optional[List[int]] = None,
...@@ -252,7 +253,7 @@ def _make_graphed_callables( ...@@ -252,7 +253,7 @@ def _make_graphed_callables(
consumed_sample_q[sample_keys].append(per_callable_fwd_idx) consumed_sample_q[sample_keys].append(per_callable_fwd_idx)
fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:] fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:]
if fp8_weight_caching: if cache_quantized_params:
# Initialize flag that controls FP8 weight updates # Initialize flag that controls FP8 weight updates
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
...@@ -687,7 +688,7 @@ def _make_graphed_callables( ...@@ -687,7 +688,7 @@ def _make_graphed_callables(
# Decide whether to update FP8 weights # Decide whether to update FP8 weights
skip_fp8_weight_update = None skip_fp8_weight_update = None
if fp8_weight_caching: if cache_quantized_params:
assert "is_first_microbatch" in user_kwargs and isinstance( assert "is_first_microbatch" in user_kwargs and isinstance(
user_kwargs["is_first_microbatch"], bool user_kwargs["is_first_microbatch"], bool
), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching."
...@@ -796,14 +797,14 @@ def _make_graphed_callables( ...@@ -796,14 +797,14 @@ def _make_graphed_callables(
def save_fp8_tensors( def save_fp8_tensors(
modules: Iterable[torch.nn.Module], modules: Iterable[torch.nn.Module],
fp8_recipe: Optional[Recipe], recipe: Optional[Recipe],
) -> Optional[List[Any]]: ) -> Optional[List[Any]]:
""" """
Returns the FP8 tensors for all modules Returns the FP8 tensors for all modules
with adjusted amax history sizes. with adjusted amax history sizes.
""" """
if not isinstance(fp8_recipe, DelayedScaling): if not isinstance(recipe, DelayedScaling):
return None return None
fp8_tensors = [] fp8_tensors = []
...@@ -812,10 +813,10 @@ def save_fp8_tensors( ...@@ -812,10 +813,10 @@ def save_fp8_tensors(
module_tensors = None module_tensors = None
if isinstance(m, TransformerEngineBaseModule): if isinstance(m, TransformerEngineBaseModule):
if m.primary_weights_in_fp8: if m.primary_weights_in_fp8:
m.adjust_amax_history_length(fp8_recipe.amax_history_len) m.adjust_amax_history_length(recipe.amax_history_len)
module_tensors = m.get_fp8_meta_tensors() module_tensors = m.get_fp8_meta_tensors()
elif isinstance(m, BasicOperation): elif isinstance(m, BasicOperation):
m.reset_recipe_state(recipe=fp8_recipe) m.reset_recipe_state(recipe=recipe)
module_tensors = m._save_fp8_metas() module_tensors = m._save_fp8_metas()
fp8_tensors.append(module_tensors) fp8_tensors.append(module_tensors)
return fp8_tensors return fp8_tensors
...@@ -850,11 +851,16 @@ def make_graphed_callables( ...@@ -850,11 +851,16 @@ def make_graphed_callables(
num_warmup_iters: int = 3, num_warmup_iters: int = 3,
allow_unused_input: bool = False, allow_unused_input: bool = False,
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
fp8_enabled: SingleOrTuple[bool] = False, fp8_enabled: Optional[SingleOrTuple[bool]] = None,
fp8_calibrating: bool = False, fp8_calibrating: Optional[bool] = None,
fp8_recipe: Optional[Recipe] = None, fp8_recipe: Optional[Recipe] = None,
fp8_group: Optional[dist_group_type] = None, fp8_group: Optional[dist_group_type] = None,
fp8_weight_caching: bool = False, fp8_weight_caching: Optional[bool] = None,
enabled: Optional[SingleOrTuple[bool]] = None,
calibrating: Optional[bool] = None,
recipe: Optional[Recipe] = None,
amax_reduction_group: Optional[dist_group_type] = None,
cache_quantized_params: Optional[bool] = None,
_order: Optional[List[int]] = None, _order: Optional[List[int]] = None,
_num_layers_per_chunk: Optional[List[int]] = None, _num_layers_per_chunk: Optional[List[int]] = None,
pool: Optional[Tuple[int, ...]] = None, pool: Optional[Tuple[int, ...]] = None,
...@@ -870,6 +876,11 @@ def make_graphed_callables( ...@@ -870,6 +876,11 @@ def make_graphed_callables(
`original PyTorch implementation <https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html>`_ `original PyTorch implementation <https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html>`_
for more documentation. for more documentation.
.. warning::
Arguments 'fp8_enabled', 'fp8_calibrating', 'fp8_recipe', 'fp8_group', and 'fp8_weight_caching' are deprecated.
Use arguments 'enabled', 'calibrating', 'recipe', 'amax_reduction_group', and 'cache_quantized_params' instead.
Graphing parameters Graphing parameters
------------------- -------------------
modules: (tuple of) callable modules: (tuple of) callable
...@@ -894,30 +905,110 @@ def make_graphed_callables( ...@@ -894,30 +905,110 @@ def make_graphed_callables(
when `_order` is provided. All callables in `modules` are assumed to have when `_order` is provided. All callables in `modules` are assumed to have
inputs and outputs with the same dtype and shape. inputs and outputs with the same dtype and shape.
FP8-related parameters Quantization-related parameters
---------------------- -------------------------------
fp8_enabled: (tuple of) bool, default = `False` enabled: (tuple of) bool, default = `False`
whether or not to enable fp8. whether or not to enable low precision quantization (FP8/FP4).
If tuple, the length must match the number of modules. If tuple, the length must match the number of modules.
fp8_calibrating: bool, default = `False` calibrating: bool, default = `False`
calibration mode allows collecting statistics such as amax and scale calibration mode allows collecting statistics such as amax and scale
data of fp8 tensors even when executing without fp8 enabled. This is data of quantized tensors even when executing without quantization enabled.
useful for saving an inference ready fp8 checkpoint while training This is useful for saving an inference-ready checkpoint while training
using a higher precision. using a higher precision.
fp8_recipe: Recipe, default = `None` recipe: recipe.Recipe, default = `None`
recipe used for FP8 training. recipe used for low precision quantization.
fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
distributed group over which amaxes for the fp8 tensors distributed group over which amaxes for the quantized tensors
are reduced at the end of each training step. are reduced at the end of each training step.
fp8_weight_caching: bool, default = `False` cache_quantized_params: bool, default = `False`
Whether or not to cache FP8 weights across microbatches. if set to `True`, Whether or not to cache quantized weights across microbatches. If set to `True`,
the `is_first_microbatch` boolean argument must be passed into the forward the `is_first_microbatch` boolean argument must be passed into the forward
method for TransformerEngine modules. When storing primary weights in FP8 method for TransformerEngine modules. When storing primary weights in low precision
using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg using TE's `quantized_model_init` API and using a quantization-aware optimizer,
must be set to `False` if calculating weight transposes' outside TE, e.g., this arg must be set to `False` if calculating weight transposes outside TE, e.g.,
in the optimizer step. in the optimizer step.
""" """
# Handle deprecated args. If old kwargs are set, they are prioritized with warning.
if fp8_enabled is not None:
if enabled is not None:
raise ValueError(
"make_graphed_callables has deprecated `fp8_enabled` kwarg "
"in favor of `enabled`, but both kwargs are set."
)
warnings.warn(
"make_graphed_callables has deprecated `fp8_enabled` kwarg in favor of `enabled`. "
"`fp8_enabled` will be removed in a future release.",
category=DeprecationWarning,
stacklevel=2,
)
enabled = fp8_enabled
if enabled is None:
enabled = False
if fp8_calibrating is not None:
if calibrating is not None:
raise ValueError(
"make_graphed_callables has deprecated `fp8_calibrating` kwarg "
"in favor of `calibrating`, but both kwargs are set."
)
warnings.warn(
"make_graphed_callables has deprecated `fp8_calibrating` kwarg in favor of "
"`calibrating`. `fp8_calibrating` will be removed in a future release.",
category=DeprecationWarning,
stacklevel=2,
)
calibrating = fp8_calibrating
if calibrating is None:
calibrating = False
if fp8_recipe is not None:
if recipe is None:
warnings.warn(
"make_graphed_callables has deprecated `fp8_recipe` kwarg in favor of "
"`recipe`. `fp8_recipe` will be removed in a future release.",
category=DeprecationWarning,
stacklevel=2,
)
else:
raise ValueError(
"make_graphed_callables has deprecated `fp8_recipe` kwarg "
"in favor of `recipe`, but both kwargs are set."
)
recipe = fp8_recipe
if fp8_group is not None:
if amax_reduction_group is None:
warnings.warn(
"make_graphed_callables has deprecated `fp8_group` kwarg in favor of "
"`amax_reduction_group`. `fp8_group` will be removed in a future release.",
category=DeprecationWarning,
stacklevel=2,
)
else:
raise ValueError(
"make_graphed_callables has deprecated `fp8_group` kwarg "
"in favor of `amax_reduction_group`, but both kwargs are set."
)
amax_reduction_group = fp8_group
if fp8_weight_caching is not None:
if cache_quantized_params is not None:
raise ValueError(
"make_graphed_callables has deprecated `fp8_weight_caching` kwarg "
"in favor of `cache_quantized_params`, but both kwargs are set."
)
warnings.warn(
"make_graphed_callables has deprecated `fp8_weight_caching` kwarg in favor of "
"`cache_quantized_params`. `fp8_weight_caching` will be removed in a future release.",
category=DeprecationWarning,
stacklevel=2,
)
cache_quantized_params = fp8_weight_caching
if cache_quantized_params is None:
cache_quantized_params = False
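The fallback above keeps the old spellings functional during the deprecation window; a hypothetical sketch (reusing `layer` and `sample_input` from the earlier example) of what a caller observes when passing a deprecated kwarg:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        graphed_layer = te.make_graphed_callables(
            layer,
            (sample_input,),
            fp8_enabled=True,  # deprecated spelling, remapped to `enabled`
        )

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)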
set_capture_start() set_capture_start()
# Handle single module. # Handle single module.
...@@ -926,21 +1017,21 @@ def make_graphed_callables( ...@@ -926,21 +1017,21 @@ def make_graphed_callables(
just_one_callable = True just_one_callable = True
modules = (modules,) modules = (modules,)
if not isinstance(fp8_enabled, tuple): if not isinstance(enabled, tuple):
assert isinstance(fp8_enabled, bool), "fp8_enabled must be a bool or a tuple of bools" assert isinstance(enabled, bool), "enabled must be a bool or a tuple of bools"
fp8_enabled = (fp8_enabled,) * len(modules) enabled = (enabled,) * len(modules)
else: else:
assert len(fp8_enabled) == len( assert len(enabled) == len(
modules modules
), f"fp8_enabled length ({len(fp8_enabled)}) must match modules length ({len(modules)})" ), f"enabled length ({len(enabled)}) must match modules length ({len(modules)})"
if any(fp8_enabled) and fp8_recipe is None: if any(enabled) and recipe is None:
fp8_recipe = get_default_fp8_recipe() recipe = get_default_fp8_recipe()
elif not any(fp8_enabled): elif not any(enabled):
fp8_recipe = None recipe = None
module_uses_fp8 = dict(zip((id(m) for m in modules), fp8_enabled)) module_uses_fp8 = dict(zip((id(m) for m in modules), enabled))
# Store FP8 tensors to reset later. # Store FP8 tensors to reset later.
saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe) saved_fp8_tensors = save_fp8_tensors(modules, recipe=recipe)
# FP8 wrapper. # FP8 wrapper.
old_call_funcs = {} old_call_funcs = {}
...@@ -954,11 +1045,11 @@ def make_graphed_callables( ...@@ -954,11 +1045,11 @@ def make_graphed_callables(
# Wrap the original call function of the module class. # Wrap the original call function of the module class.
def call_func(self, *args, **kwargs): def call_func(self, *args, **kwargs):
with fp8_autocast( with autocast(
enabled=module_uses_fp8.get(id(self), False), enabled=module_uses_fp8.get(id(self), False),
calibrating=fp8_calibrating, calibrating=calibrating,
fp8_recipe=fp8_recipe, recipe=recipe,
fp8_group=fp8_group, amax_reduction_group=amax_reduction_group,
_graph=True, _graph=True,
): ):
outputs = old_call_funcs[block_cls](self, *args, **kwargs) outputs = old_call_funcs[block_cls](self, *args, **kwargs)
...@@ -992,7 +1083,7 @@ def make_graphed_callables( ...@@ -992,7 +1083,7 @@ def make_graphed_callables(
sample_args, sample_args,
num_warmup_iters=num_warmup_iters, num_warmup_iters=num_warmup_iters,
allow_unused_input=allow_unused_input, allow_unused_input=allow_unused_input,
fp8_weight_caching=fp8_weight_caching, cache_quantized_params=cache_quantized_params,
sample_kwargs=sample_kwargs, sample_kwargs=sample_kwargs,
_order=_order, _order=_order,
_num_layers_per_chunk=_num_layers_per_chunk, _num_layers_per_chunk=_num_layers_per_chunk,
......
...@@ -4,16 +4,17 @@ ...@@ -4,16 +4,17 @@
"""Internal function used by multiple modules.""" """Internal function used by multiple modules."""
from typing import Any, List, Optional, Tuple, Union, Callable import dataclasses
from dataclasses import dataclass
import queue import queue
from typing import Any, Callable, List, Optional, Tuple, Union
import torch import torch
from .. import cpp_extensions as tex from .. import cpp_extensions as tex
from ..constants import TE_DType from ..constants import TE_DType
from ..utils import get_default_init_method
from ..export import is_in_onnx_export_mode from ..export import is_in_onnx_export_mode
from ..utils import get_default_init_method
import warnings import warnings
try: try:
from lightop import rmsnorm_forward, rmsnorm_backward from lightop import rmsnorm_forward, rmsnorm_backward
...@@ -179,7 +180,7 @@ def noop_cat( ...@@ -179,7 +180,7 @@ def noop_cat(
return _NoopCatFunc.apply(dim, *tensors) return _NoopCatFunc.apply(dim, *tensors)
@dataclass @dataclasses.dataclass
class _ParameterInitMeta: class _ParameterInitMeta:
""" """
Stores essential metadata needed to support deferred parameter initialization. Stores essential metadata needed to support deferred parameter initialization.
......
...@@ -22,11 +22,12 @@ import transformer_engine_torch as tex ...@@ -22,11 +22,12 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from ._common import _ParameterInitMeta, noop_cat from ._common import _ParameterInitMeta, noop_cat
from ..fp8 import ( from ..quantization import (
MXFP8BlockScalingRecipeState, MXFP8BlockScalingRecipeState,
DelayedScalingRecipeState, DelayedScalingRecipeState,
Float8CurrentScalingRecipeState, Float8CurrentScalingRecipeState,
Float8BlockScalingRecipeState, Float8BlockScalingRecipeState,
NVFP4BlockScalingRecipeState,
FP8GlobalStateManager, FP8GlobalStateManager,
RecipeState, RecipeState,
) )
...@@ -37,14 +38,14 @@ from ..distributed import ( ...@@ -37,14 +38,14 @@ from ..distributed import (
_fsdp_gather_tensors, _fsdp_gather_tensors,
) )
from ..constants import dist_group_type from ..constants import dist_group_type
from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ...common.recipe import DelayedScaling, Recipe from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
...@@ -82,7 +83,8 @@ def get_cublas_workspace_size_bytes() -> None: ...@@ -82,7 +83,8 @@ def get_cublas_workspace_size_bytes() -> None:
if IS_HIP_EXTENSION: if IS_HIP_EXTENSION:
return 134_217_728 return 134_217_728
if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
return 33_554_432 # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales
return 32 * 1024 * 1024 + 1024
return 4_194_304 return 4_194_304
...@@ -547,7 +549,7 @@ def fill_userbuffers_buffer_for_all_gather( ...@@ -547,7 +549,7 @@ def fill_userbuffers_buffer_for_all_gather(
local_tensor: torch.Tensor, local_tensor: torch.Tensor,
quantizer: Optional[Quantizer], quantizer: Optional[Quantizer],
process_group, process_group,
) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]: ) -> tuple[torch.Tensor | QuantizedTensorStorage, torch.Tensor | QuantizedTensorStorage]:
"""Fill local shard of Userbuffers buffer with data for all-gather """Fill local shard of Userbuffers buffer with data for all-gather
Returns the full tensor and the local shard, both using the Returns the full tensor and the local shard, both using the
...@@ -571,7 +573,7 @@ def fill_userbuffers_buffer_for_all_gather( ...@@ -571,7 +573,7 @@ def fill_userbuffers_buffer_for_all_gather(
# Unquantized data # Unquantized data
if quantizer is None: if quantizer is None:
if isinstance(local_tensor, QuantizedTensorBase): if isinstance(local_tensor, QuantizedTensorStorage):
local_tensor = local_tensor.dequantize() local_tensor = local_tensor.dequantize()
if comm.is_fp8_ubuf(): if comm.is_fp8_ubuf():
raise RuntimeError( raise RuntimeError(
...@@ -584,8 +586,8 @@ def fill_userbuffers_buffer_for_all_gather( ...@@ -584,8 +586,8 @@ def fill_userbuffers_buffer_for_all_gather(
# FP8 data # FP8 data
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
if not isinstance(local_tensor, Float8TensorBase): if not isinstance(local_tensor, Float8TensorStorage):
if isinstance(local_tensor, QuantizedTensorBase): if isinstance(local_tensor, QuantizedTensorStorage):
local_tensor.dequantize() local_tensor.dequantize()
quantizer.set_usage(rowwise=True, columnwise=False) quantizer.set_usage(rowwise=True, columnwise=False)
local_tensor = quantizer(local_tensor) local_tensor = quantizer(local_tensor)
...@@ -596,7 +598,7 @@ def fill_userbuffers_buffer_for_all_gather( ...@@ -596,7 +598,7 @@ def fill_userbuffers_buffer_for_all_gather(
) )
comm.copy_into_buffer(local_tensor._data, local_chunk=True) comm.copy_into_buffer(local_tensor._data, local_chunk=True)
global_tensor_data = comm.get_buffer(shape=global_shape) global_tensor_data = comm.get_buffer(shape=global_shape)
global_tensor = Float8TensorBase( global_tensor = Float8TensorStorage(
data=global_tensor_data, data=global_tensor_data,
fp8_scale_inv=local_tensor._scale_inv, fp8_scale_inv=local_tensor._scale_inv,
fp8_dtype=local_tensor._fp8_dtype, fp8_dtype=local_tensor._fp8_dtype,
...@@ -608,8 +610,8 @@ def fill_userbuffers_buffer_for_all_gather( ...@@ -608,8 +610,8 @@ def fill_userbuffers_buffer_for_all_gather(
if isinstance(quantizer, MXFP8Quantizer): if isinstance(quantizer, MXFP8Quantizer):
# Cast to MXFP8 if needed # Cast to MXFP8 if needed
if not isinstance(local_tensor, MXFP8TensorBase): if not isinstance(local_tensor, MXFP8TensorStorage):
if isinstance(local_tensor, QuantizedTensorBase): if isinstance(local_tensor, QuantizedTensorStorage):
local_tensor.dequantize() local_tensor.dequantize()
local_tensor = quantizer(local_tensor) local_tensor = quantizer(local_tensor)
if not comm.is_fp8_ubuf(): if not comm.is_fp8_ubuf():
...@@ -664,7 +666,7 @@ def fill_userbuffers_buffer_for_all_gather( ...@@ -664,7 +666,7 @@ def fill_userbuffers_buffer_for_all_gather(
rowwise_data, rowwise_scale_inv = global_data, global_scale_inv rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
else: else:
columnwise_data, columnwise_scale_inv = global_data, global_scale_inv columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
global_tensor = MXFP8TensorBase( global_tensor = MXFP8TensorStorage(
rowwise_data=rowwise_data, rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv, rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data, columnwise_data=columnwise_data,
...@@ -802,6 +804,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -802,6 +804,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
recipe_state, Float8BlockScalingRecipeState recipe_state, Float8BlockScalingRecipeState
): ):
return return
if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState):
return
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd # 2 (grad_output and grad_input) for bwd
...@@ -826,10 +830,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -826,10 +830,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
f"({len(weight_quantizers)}) must match" f"({len(weight_quantizers)}) must match"
) )
for weight, quantizer in zip(weight_tensors, weight_quantizers): for weight, quantizer in zip(weight_tensors, weight_quantizers):
if quantizer is not None and isinstance(weight, QuantizedTensorBase): if quantizer is not None and isinstance(weight, QuantizedTensorStorage):
weight.update_quantizer(quantizer) weight.update_quantizer(quantizer)
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module.""" """Get the weight tensors of the module."""
raise NotImplementedError( raise NotImplementedError(
f"{self.__class__.__name__} class does not implement _get_weight_tensors function" f"{self.__class__.__name__} class does not implement _get_weight_tensors function"
...@@ -1011,12 +1015,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1011,12 +1015,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return return
dtype = inp.dtype dtype = inp.dtype
for name, param in self.named_parameters(): if not self.allow_different_data_and_param_types:
if param is not None: for name, param in self.named_parameters():
assert dtype == param.dtype, ( if param is not None:
"Data types for parameters must match when outside of autocasted region. " assert dtype == param.dtype, (
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" "Data types for parameters must match when outside of autocasted region. "
) f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype self.activation_dtype = dtype
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
...@@ -1077,8 +1082,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1077,8 +1082,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
# Set FP8_MAX per tensor according to recipe # Set FP8_MAX per tensor according to recipe
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd if hasattr(self.fp8_meta["recipe"], "fp8_format"):
self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
# Allocate scales and amaxes # Allocate scales and amaxes
self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
...@@ -1105,6 +1111,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1105,6 +1111,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
inp: torch.Tensor, inp: torch.Tensor,
num_gemms: int = 1, num_gemms: int = 1,
allow_non_contiguous: bool = False, allow_non_contiguous: bool = False,
allow_different_data_and_param_types: bool = False,
) -> Generator[torch.Tensor, None, None]: ) -> Generator[torch.Tensor, None, None]:
"""Checks and prep for FWD. """Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know The context manager is needed because there isn't a way for a module to know
...@@ -1112,6 +1119,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1112,6 +1119,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
to setup the forward aggregated amax reduction for every module to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one. just in case. The autocast exit will pick up the most recent one.
""" """
self.allow_different_data_and_param_types = allow_different_data_and_param_types
self.forwarded_at_least_once = True self.forwarded_at_least_once = True
# Activation recomputation is used and this is the second forward phase. # Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase(): if self.fp8 and in_fp8_activation_recompute_phase():
...@@ -1207,9 +1215,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1207,9 +1215,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_output, grad_output,
( (
QuantizedTensor, QuantizedTensor,
Float8TensorBase, Float8TensorStorage,
MXFP8TensorBase, MXFP8TensorStorage,
Float8BlockwiseQTensorBase, Float8BlockwiseQTensorStorage,
), ),
): ):
grad_output = quantizer(grad_output) grad_output = quantizer(grad_output)
...@@ -1238,9 +1246,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1238,9 +1246,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_output_.get_tensor(True), grad_output_.get_tensor(True),
( (
QuantizedTensor, QuantizedTensor,
Float8TensorBase, Float8TensorStorage,
MXFP8TensorBase, MXFP8TensorStorage,
Float8BlockwiseQTensorBase, Float8BlockwiseQTensorStorage,
), ),
) )
and ctx.use_bias and ctx.use_bias
...@@ -1256,7 +1264,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1256,7 +1264,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if ctx.use_bias: if ctx.use_bias:
if isinstance( if isinstance(
grad_output, grad_output,
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), (
QuantizedTensor,
Float8TensorStorage,
MXFP8TensorStorage,
Float8BlockwiseQTensorStorage,
),
): ):
grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
else: else:
...@@ -1265,10 +1278,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1265,10 +1278,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else: else:
grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
if not isinstance( if not isinstance(grad_output, QuantizedTensorStorage):
grad_output,
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
):
grad_output = quantizer(grad_output) grad_output = quantizer(grad_output)
return grad_output, grad_bias return grad_output, grad_bias
...@@ -1422,14 +1432,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1422,14 +1432,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Reset cache if workspace is invalid # Reset cache if workspace is invalid
if out is not None and quantizer is not None: if out is not None and quantizer is not None:
reset_cache = False reset_cache = False
if isinstance(out, Float8TensorBase): if isinstance(out, Float8TensorStorage):
if ( if (
not is_non_tn_fp8_gemm_supported() not is_non_tn_fp8_gemm_supported()
and quantizer.columnwise_usage and quantizer.columnwise_usage
and out._transpose is None and out._transpose is None
): ):
reset_cache = True reset_cache = True
elif isinstance(out, MXFP8TensorBase): elif isinstance(out, MXFP8TensorStorage):
if quantizer.rowwise_usage and out._rowwise_data is None: if quantizer.rowwise_usage and out._rowwise_data is None:
reset_cache = True reset_cache = True
elif quantizer.columnwise_usage and out._columnwise_data is None: elif quantizer.columnwise_usage and out._columnwise_data is None:
...@@ -1609,8 +1619,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1609,8 +1619,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
- MXFP8BlockScaling → MXFP8Tensor - MXFP8BlockScaling → MXFP8Tensor
- Float8BlockScaling → Float8BlockTensor - Float8BlockScaling → Float8BlockTensor
Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()), Example case to check: recipe is DelayedScaling (DelayedScaling is set in autocast()),
but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()). but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in quantized_model_init()).
""" """
if not self.fp8 and not self.fp8_calibration: if not self.fp8 and not self.fp8_calibration:
return return
...@@ -1620,7 +1630,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1620,7 +1630,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
recipe = self.fp8_meta["recipe"] recipe = self.fp8_meta["recipe"]
weight_tensors = [getattr(self, name) for name in self.weight_names] weight_tensors = [getattr(self, name) for name in self.weight_names]
for i, tensor in enumerate(weight_tensors): for i, tensor in enumerate(weight_tensors):
if isinstance(tensor, QuantizedTensorBase): if isinstance(tensor, QuantizedTensorStorage):
quantizer = tensor._get_quantizer() quantizer = tensor._get_quantizer()
if quantizer is None: if quantizer is None:
continue continue
...@@ -1631,6 +1641,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1631,6 +1641,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
raise RuntimeError( raise RuntimeError(
f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe" f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe"
f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}." f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}."
" Please check the recipes assigned during fp8_model_init() and" " Please check the recipes assigned during quantized_model_init() and"
" fp8_autocast() calls." " autocast() calls."
) )
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from ..fp8 import FP8GlobalStateManager from ..quantization import FP8GlobalStateManager
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from ..fp8 import FP8GlobalStateManager from ..quantization import FP8GlobalStateManager
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
......