Unverified Commit 3f5b4754 authored by Kirthi Shankar Sivamani, committed by GitHub

[Core][PyTorch] NVFP4 recipe (#2177)



* Add NVFP4 recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Frank Sun <frsun@nvidia.com>
Co-authored-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add MathDx dependency to GitHub builds
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Suggestions from GitHub Copilot
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move 2x shape logic from core to PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix compilation errors with CUDA 12.1
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* SM 70 is not supported in CUDA 13
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Typo
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Revert "Move 2x shape logic from core to PyTorch"

This reverts commit f8b2a2d0111d9af690b43bb98ae448d9a430a185.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Added dequantize kernel for FP4
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add NVFP4 support with fusible ops

Use logical tensor dims for PyTorch NVFP4 tensors. Temporarily add unfused dequantize impl. Fix bug where NVFP4 recipe was not configurable.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix logic for 2x shapes and move to PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CG test model config
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Debug NVFP4 tensor size function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Proper handling of the RNG state
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Test SR properly
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix workspace size for GEMM heuristic.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix compile error in C++ NVFP4 test

Fix some numeric errors when blocks are all zero.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix distributed test problem shape
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* proper assert dim for low precision AG TP
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* clean up duplicated code in nvfp4_utils.cuh
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* pylint: disable=unused-argument
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* `nvte_cublas_gemm_v2` to take alpha pointer (#12)

* make nvte_cublas_gemm_v2 take alpha/beta pointers
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* users are expected to pass a valid C_tensor
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* typos
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* API to have const float* alpha
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Minor tweaks

Support arbitrary beta scales. Increase workspace to be aligned to 128 bytes.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug IMA with alpha pointer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Support fused amax kernels with NVFP4 quantization
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable fused amax with cuDNN LayerNorm kernel
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add NVFP4 cases to distributed tests for TE ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Change assert to NVTE_CHECK in the hadamard cast fusion
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix compile error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use global thread IDs for Philox subsequences
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add shape checks for NVFP4 cast kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Do not fuse amax if cuDNN normalization is forced by envvar
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Frank Sun <frsun@nvidia.com>
Co-authored-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
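
For context, a minimal usage sketch of the recipe added in this PR. This is a sketch only: it assumes the class is exposed as transformer_engine.common.recipe.NVFP4BlockScaling and is consumed through the existing fp8_autocast context manager the same way the other recipes in recipe.py are, and it assumes hardware with NVFP4 support.

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import NVFP4BlockScaling

    # Defaults follow the dataclass: E2M1 data, E4M3 block scales, FP32 per-tensor scale.
    # RHT, stochastic rounding, and 2D weight quantization can be disabled via the
    # NVTE_NVFP4_DISABLE_* environment variables read in __post_init__.
    recipe = NVFP4BlockScaling()

    layer = te.Linear(1024, 1024, bias=True).cuda()
    inp = torch.randn(512, 1024, device="cuda")

    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = layer(inp)
    out.sum().backward()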
parent dfeef1a2
@@ -28,7 +28,7 @@ void layernorm_fwd(const Tensor& x,  // BxSxhidden_size
                    const int multiprocessorCount, const bool zero_centered_gamma,
                    cudaStream_t stream) {
   if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
-      !is_mxfp_scaling(z->scaling_mode)) {
+      !is_mxfp8_scaling(z->scaling_mode)) {
     NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
   }
@@ -63,7 +63,7 @@ void layernorm_fwd(const Tensor& x,  // BxSxhidden_size
   NVTE_Norm_Backend norm_backend;
   bool is_aligned = true;
-  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
+  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
   if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
     NVTE_CHECK(!cudnn_backend,
...
@@ -24,7 +24,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
                  Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
                  const bool zero_centered_gamma, cudaStream_t stream) {
   if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
-      !is_mxfp_scaling(z->scaling_mode)) {
+      !is_mxfp8_scaling(z->scaling_mode)) {
     NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
   }
@@ -49,7 +49,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
   NVTE_Norm_Backend norm_backend;
   bool is_aligned = true;
-  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
+  bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
   if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
     NVTE_CHECK(!cudnn_backend,
...
@@ -4,7 +4,6 @@
 """This module provides predefined FP8 recipes."""
 from __future__ import annotations
-import warnings
 import os
 from enum import Enum
 from typing import Literal, Optional, Union, Callable, NamedTuple
@@ -23,9 +22,12 @@ class _FormatHelper(NamedTuple):
 class Format(Enum):
     """
     Supported FP8 formats.
+    Supported FP4 formats.

     Values
     ------
+    E2M1 :
+          All FP4 tensors are in e2m1 format
     E4M3 :
           All FP8 tensors are in e4m3 format
     E5M2 :
@@ -35,6 +37,7 @@ class Format(Enum):
           FP8 tensors in the backward pass are in e5m2 format
     """

+    E2M1 = _FormatHelper(max_fwd=6, max_bwd=6)
     E4M3 = _FormatHelper(max_fwd=448, max_bwd=448)
     E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344)
     HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)
@@ -42,9 +45,13 @@ class Format(Enum):
 @dataclass(frozen=True)
 class MMParams:
-    """for pytorch as an example, _scaled_mm use_fast_accum = (not use_split_accumulator)
-    apply split accumulator or not, turning it on will increase accuracy but impact gemm performance,
-    so only turn it on for certain gemms
+    """Matrix multiplication options.
+
+    Parameters
+    ----------
+    use_split_accumulator : bool, default = `True`
+        Use FP8 fast accumulation on Hopper or Ada. For more details,
+        see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
     """

     use_split_accumulator: bool = True
@@ -55,10 +62,24 @@ class QParams:
     """Quantization parameters.
     power_2_scale: use power of 2 scale parameter
     amax_epsilon: optional minimum value of abs max
+    random_hadamard_transform: whether to use random hadamard transform
+    stochastic_rounding: whether to use stochastic rounding
     """

     power_2_scale: bool = False
     amax_epsilon: float = 0.0
+    random_hadamard_transform: bool = False
+    stochastic_rounding: bool = False
+    fp4_2d_quantization: bool = False
+
+    def __repr__(self) -> str:
+        return (
+            f"Qparams(\npower_2_scale={self.power_2_scale},\n"
+            f"amax_epsilon={self.amax_epsilon},\n"
+            f"random_hadamard_transform={self.random_hadamard_transform},\n"
+            f"stochastic_rounding={self.stochastic_rounding},\n"
+            f"fp4_2d_quantization={self.fp4_2d_quantization}\n)"
+        )

 class Recipe:
@@ -66,6 +87,10 @@ class Recipe:
     Base recipe class.
     """

+    def nvfp4(self):
+        """Whether the given recipe is NVFP4 1D block scaling."""
+        return isinstance(self, NVFP4BlockScaling)
+
     def mxfp8(self):
         """Whether the given recipe is MXFP8 block scaling."""
         return isinstance(self, MXFP8BlockScaling)
@@ -351,3 +376,84 @@ class Float8BlockScaling(Recipe):
             f"fp8_dpa={self.fp8_dpa}, "
             f"fp8_mha={self.fp8_mha}"
         )
+@dataclass()
+class NVFP4BlockScaling(Recipe):
+    """
+    Use the NVFP4 scaling strategy.
+
+    This is a 2-level block scaling strategy. In level 1, each group of
+    16 consecutive values is scaled together using their own scaling
+    factor. The type of the scaling factor is E4M3 (4 bits of exponent,
+    3 bits of mantissa). In level 2, a global per tensor FP32 scaling
+    factor is used to scale the entire tensor.
+
+    Since the scaling happens in a particular direction (either rowwise
+    or columnwise), in this recipe the quantized tensor and its transpose
+    are not numerically equivalent. Due to this, when Transformer Engine
+    needs both the tensor and its transpose (e.g. to calculate both
+    forward and backward pass), during the quantization both versions are
+    computed from the high precision input to avoid double quantization
+    errors.
+
+    Parameters
+    ----------
+    fp4_format : {Format.E2M1}, default = Format.E2M1
+        FP4 data type.
+    fp8_format : {Format.E4M3}, default = Format.E4M3
+        FP8 data type. Only E4M3 is supported.
+    fp8_dpa: bool, default = `False`
+        FP8 dot product attention. Not yet supported.
+    fp8_mha: bool, default = `False`
+        FP8 multi-head attention. Not yet supported.
+    """
+
+    # Configuration envvars
+    disable_rht: bool = os.getenv("NVTE_NVFP4_DISABLE_RHT", "0") == "1"
+    disable_stochastic_rounding: bool = (
+        os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1"
+    )
+    disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1"
+
+    fp4_format: Format = Format.E2M1
+    fp8_format: Format = Format.E4M3
+    # Not applying quantization to attention for now
+    fp8_dpa: bool = False
+    fp8_mha: bool = False
+
+    def __post_init__(self) -> None:
+        assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling"
+        assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling"
+
+        # Quantization params
+        # Note: RHT is currently only applied to column-wise usage so that
+        # it can be used for wgrad GEMM.
+        self.fp4_quant_fwd_inp = QParams(
+            random_hadamard_transform=not self.disable_rht,
+            stochastic_rounding=False,
+            fp4_2d_quantization=False,
+        )
+        self.fp4_quant_fwd_weight = QParams(
+            random_hadamard_transform=False,
+            stochastic_rounding=False,
+            fp4_2d_quantization=not self.disable_2d_quantization,
+        )
+        self.fp4_quant_bwd_grad = QParams(
+            random_hadamard_transform=not self.disable_rht,
+            stochastic_rounding=not self.disable_stochastic_rounding,
+            fp4_2d_quantization=False,
+        )
+
+    def __repr__(self) -> str:
+        return (
+            f"recipe_type={self.__class__.__name__}, "
+            f"fp4_format={str(self.fp4_format).split('.')[1]}, "
+            f"fp8_format={str(self.fp8_format).split('.')[1]}, "
+            f"fp8_dpa={self.fp8_dpa}, "
+            f"fp8_mha={self.fp8_mha}, "
+            f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, "
+            f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, "
+            f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, "
+        )
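
For reference, the two-level scheme the docstring above describes can be written out in a few lines of NumPy. This is a sketch only: the helper names are illustrative, FP32 stands in for the E4M3 block scales, and nearest-value snapping stands in for the cvt.rn/cvt.rs conversions used by the CUDA kernels later in this change. The constants 6.0 and 448.0 are the E2M1 and E4M3 maxima that also back Format.E2M1 and Format.E4M3.

    import numpy as np

    FP4_MAX = 6.0    # largest E2M1 magnitude
    FP8_MAX = 448.0  # largest E4M3 magnitude (dtype of the stored block scales)
    E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

    def snap_to_e2m1(v):
        # Round each value to the nearest representable E2M1 magnitude, keeping the sign.
        idx = np.abs(np.abs(v)[..., None] - E2M1_GRID).argmin(axis=-1)
        return np.sign(v) * E2M1_GRID[idx]

    def nvfp4_reference_quantize(x, block=16):
        # Level 2: one FP32 scale for the whole tensor, derived from the global amax.
        global_amax = np.abs(x).max()
        global_encode = FP8_MAX * FP4_MAX / global_amax if global_amax > 0 else 1.0
        # Level 1: one scale per group of 16 consecutive values (stored as E4M3 in the recipe).
        blocks = x.reshape(-1, block)  # assumes x.size is a multiple of the block size
        block_amax = np.abs(blocks).max(axis=1, keepdims=True)
        decode = block_amax / FP4_MAX * global_encode
        encode = np.where(block_amax > 0, FP4_MAX / block_amax, 1.0)
        q = snap_to_e2m1(blocks * encode)
        # Dequantization: q * decode * (1 / global_encode) recovers each block up to E2M1 rounding.
        return q, decode, 1.0 / global_encode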
@@ -20,6 +20,13 @@ namespace {

 constexpr int amax_kernel_threads = 512;

+__launch_bounds__(1) __global__ void zero_amax_kernel(float *amax_ptr, const float *noop_ptr) {
+  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
+    return;
+  }
+  *amax_ptr = 0;
+}
+
 template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
     void amax_kernel(const InputType *input, float *amax, const size_t N,
@@ -65,7 +72,8 @@ template <int nvec, typename InputType>
 void launch_amax_kernel(const InputType *input, float *amax, const size_t N, const float *noop_ptr,
                         cudaStream_t stream) {
   // Zero out amax so we can update with atomic max
-  NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream));
+  zero_amax_kernel<<<1, 1, 0, stream>>>(amax, noop_ptr);
+  NVTE_CHECK_CUDA(cudaGetLastError());

   // Return immediately if tensor is empty
   if (N == 0) {
@@ -130,15 +138,17 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt
   // Check output tensor
   NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
   auto &output = *convertNVTETensorCheck(output_);
-  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
-             "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
+  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING ||
+                 output.scaling_mode == NVTE_NVFP4_1D_SCALING,
+             "Output tensor for amax computation must be FP8 tensor with per-tensor scaling or "
+             "NVFP4 1D scaling, "
              "but got scaling_mode=",
              to_string(output.scaling_mode));
   NVTE_CHECK(output.amax.numel() == 1,
              "Output tensor for amax computation has invalid amax tensor "
              "(expected 1 entry, got shape=",
              output.amax.shape, ")");
-  NVTE_CHECK(output.amax.dptr != nullptr,
+  NVTE_CHECK(output.amax.dptr != nullptr || output.columnwise_amax.dptr != nullptr,
              "Output tensor for amax computation has amax tensor without data");
   NVTE_CHECK(output.amax.dtype == DType::kFloat32,
              "Output tensor for amax computation has invalid amax tensor "
@@ -157,11 +167,12 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt
   }

   // Compute amax
+  float *amax_ptr = reinterpret_cast<float *>(
+      (output.amax.dptr != nullptr) ? output.amax.dptr : output.columnwise_amax.dptr);
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-      input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
-      launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
-                               reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
-                               noop_ptr, stream););  // NOLINT(*)
+      input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); launch_amax_kernel<nvec>(
+          reinterpret_cast<const IType *>(input.data.dptr), amax_ptr, input.data.numel(), noop_ptr,
+          stream););  // NOLINT(*)
 }

 }  // anonymous namespace
...
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/recipe.h>
#include <cassert>
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace nvfp4_recipe {
// constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0;
constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0);
// Kernel to compute alpha *= amax_A * amax_B / factor
__global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A,
const float *amax_B, float *alpha_out) {
// factor is defined in the enclosing namespace
*alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv;
}
} // namespace nvfp4_recipe
} // namespace transformer_engine
void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A,
const NVTETensor inpB, const bool use_rowwise_amax_B,
float alpha_in, NVTETensor alpha_out,
cudaStream_t stream) {
NVTE_API_CALL(nvte_nvfp4_compute_per_tensor_scale);
using namespace transformer_engine;
auto *tA = convertNVTETensor(inpA);
auto *tB = convertNVTETensor(inpB);
auto *tOut = convertNVTETensor(alpha_out);
void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr;
void *amax_B_ptr = use_rowwise_amax_B ? tB->amax.dptr : tB->columnwise_amax.dptr;
void *alpha_ptr = tOut->data.dptr;
// check for not null pointers
NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null");
NVTE_CHECK(amax_B_ptr != nullptr, "amax_B_ptr is null");
NVTE_CHECK(alpha_ptr != nullptr, "alpha_ptr is null");
nvfp4_recipe::compute_nvfp4_per_tensor_scale_kernel<<<1, 1, 0, stream>>>(
alpha_in, reinterpret_cast<const float *>(amax_A_ptr),
reinterpret_cast<const float *>(amax_B_ptr), reinterpret_cast<float *>(alpha_ptr));
NVTE_CHECK_CUDA(cudaGetLastError());
}
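
The single-thread kernel above folds both operands' per-tensor amaxes into the GEMM alpha on the device. A host-side sketch of the same arithmetic, for reference only (6.0 and 448.0 are the E2M1 and E4M3 maxima used when the operands were quantized, so the factor undoes the two per-tensor encode scales):

    def nvfp4_gemm_alpha(alpha_in, amax_a, amax_b):
        # Matches factor_inv = 1 / (6.0 * 6.0 * 448.0 * 448.0) in the kernel: it cancels the
        # per-tensor encode scale (448 * 6 / amax) baked into each quantized operand.
        return alpha_in * amax_a * amax_b / (6.0 * 6.0 * 448.0 * 448.0)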
@@ -18,7 +18,9 @@
 namespace transformer_engine {
 namespace {

-constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32;
+constexpr int MXFP8_BLOCK_SIZE = 32;
+constexpr int NVFP4_BLOCK_SIZE = 16;
 constexpr __device__ __host__ int TB_DIM = 32;
 constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16;
 constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4;
@@ -314,8 +316,6 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
   const int original_K = kernel_args.original_k_list[tensor_id];

   constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
-  constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
-  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;

   // Get block index in grid. Emulate 2D grid.
   const int num_tiles_k = K / SF_TILE_DIM_K;
@@ -332,9 +332,13 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
 }  // namespace

 void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
-  if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) {
-    NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + ".");
-  }
+  NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING ||
+                 input->scaling_mode == NVTE_BLOCK_SCALING_1D ||
+                 input->scaling_mode == NVTE_BLOCK_SCALING_2D ||
+                 input->scaling_mode == NVTE_NVFP4_1D_SCALING,
+             "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
+  NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()),
+             "Input tensor has invalid dtype (", to_string(input->dtype()), ").");

   // Do nothing if tensor is empty
   if (input->data.numel() == 0) {
@@ -345,13 +349,25 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
   CheckInputTensor(*output, "scaling_factor_output");

   auto& scaling_mode = input->scaling_mode;
+  NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
+             "Unsupported scaling mode for swizzling.");
+  bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;

   // 1D block scaling, row-wise or colum-wise
-  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
-    const int m =
-        input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1];
-    const int k =
-        input->has_data() ? input->scale_inv.shape[1] : input->columnwise_scale_inv.shape[0];
+  int m, k;
+  if (input->has_data()) {
+    m = input->scale_inv.shape[0];
+    k = input->scale_inv.shape[1];
+  } else {
+    if (nvfp4) {
+      m = input->columnwise_scale_inv.shape[0];
+      k = input->columnwise_scale_inv.shape[1];
+    } else {
+      m = input->columnwise_scale_inv.shape[1];
+      k = input->columnwise_scale_inv.shape[0];
+    }
+  }

   constexpr int SF_TILE_DIM_M = 128;
   constexpr int SF_TILE_DIM_K = 4;
@@ -375,16 +391,35 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
   int num_tiles_m = m / SF_TILE_DIM_M;
   int num_tiles_k = k / SF_TILE_DIM_K;

+  // For NVFP4, the scale inverse for transposed data needs rowwise swizzle.
+  const bool rowwise_swizzle = input->has_data() || nvfp4;
+  const bool columnwise_swizzle = input->has_columnwise_data() && !nvfp4;
+
   dim3 block_size(TB_DIM, TB_DIM);
-  if (input->has_data()) {
+  if (rowwise_swizzle) {
     int vec_load_size = (num_tiles_k - 1) % 4 + 1;
     /* there is no int3 and misaligned if using int4/int2 */
     if (vec_load_size == 3) vec_load_size = 1;
     int n_tiles_in_tb = TB_DIM * vec_load_size;
     dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
     int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
-    const int original_M = input->flat_first_dim();
-    const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;
+
+    int original_M, original_K;
+    void *input_scale_inv_ptr, *output_scale_inv_ptr;
+    if (!nvfp4 || input->has_data()) {
+      int block_scale_size = nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE;
+      original_M = input->flat_first_dim();
+      original_K = input->flat_last_dim() / block_scale_size;
+      input_scale_inv_ptr = input->scale_inv.dptr;
+      output_scale_inv_ptr = output->scale_inv.dptr;
+    } else {
+      original_M = input->flat_last_dim();
+      original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;
+      input_scale_inv_ptr = input->columnwise_scale_inv.dptr;
+      output_scale_inv_ptr = output->columnwise_scale_inv.dptr;
+    }
+
     switch (vec_load_size) {
       case 4:
         NVTE_CHECK_CUDA(
@@ -392,7 +427,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
                 cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
         swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
             <<<num_blocks, block_size, slm_size, stream>>>(
-                input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
+                input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
         break;
       case 2:
         NVTE_CHECK_CUDA(
@@ -400,7 +435,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
                 cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
         swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
             <<<num_blocks, block_size, slm_size, stream>>>(
-                input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
+                input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
         break;
       case 1:
         NVTE_CHECK_CUDA(
@@ -408,15 +443,14 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
                 cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
         swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
             <<<num_blocks, block_size, slm_size, stream>>>(
-                input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
+                input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
         break;
       default:
         NVTE_ERROR("Not valid vec_load_size.");
         break;
     }
-    NVTE_CHECK_CUDA(cudaGetLastError());
   }
-  if (input->has_columnwise_data()) {
+  if (columnwise_swizzle) {
     int vec_load_size = (num_tiles_m - 1) % 4 + 1;
     if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */
     int n_tiles_in_tb = TB_DIM * vec_load_size;
@@ -424,6 +458,9 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
     int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
     const int original_M = input->flat_last_dim();
     const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;
+
+    // NVFP4 shouldn't end up here because it only needs rowwise swizzle
+    NVTE_CHECK(!nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle");
     switch (vec_load_size) {
       case 4:
         NVTE_CHECK_CUDA(
@@ -431,8 +468,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
                 cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
         swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
             <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
-                                                           output->columnwise_scale_inv.dptr, m,
-                                                           k, original_M, original_K);
+                                                           output->columnwise_scale_inv.dptr, m, k,
+                                                           original_M, original_K);
         break;
       case 2:
         NVTE_CHECK_CUDA(
@@ -440,8 +477,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
                 cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
         swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
             <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
-                                                           output->columnwise_scale_inv.dptr, m,
-                                                           k, original_M, original_K);
+                                                           output->columnwise_scale_inv.dptr, m, k,
+                                                           original_M, original_K);
         break;
       case 1:
         NVTE_CHECK_CUDA(
@@ -449,19 +486,13 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
                 cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
         swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
             <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
-                                                           output->columnwise_scale_inv.dptr, m,
-                                                           k, original_M, original_K);
+                                                           output->columnwise_scale_inv.dptr, m, k,
+                                                           original_M, original_K);
         break;
       default:
         NVTE_ERROR("Not valid vec_load_size.");
         break;
     }
-    NVTE_CHECK_CUDA(cudaGetLastError());
-  }
-    // 2D block scaling
-  } else {
-    NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans.");
   }

   NVTE_CHECK_CUDA(cudaGetLastError());
@@ -551,6 +582,8 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
   }
   NVTE_CHECK_CUDA(cudaGetLastError());
 }

+// TODO(nvfp4): Add NVFP4 support.
 void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
                                           std::vector<Tensor*>& output, cudaStream_t stream) {
   auto num_tensors = input.size();
@@ -677,7 +710,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
 * WIP (Phuong):
 * - Opt for bank conflicts
 * - Adding swizzle for 2d-block scaling.
 */
 void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_swizzle_scaling_factors);
   using namespace transformer_engine;
...
@@ -11,6 +11,7 @@
 #include <cstring>
 #include <iostream>
 #include <mutex>
+#include <utility>

 #include "common.h"
 #include "common/util/cuda_runtime.h"
@@ -63,8 +64,8 @@ std::string to_string(const NVTEScalingMode &mode) {
       return "NVTE_DELAYED_TENSOR_SCALING";
     case NVTE_MXFP8_1D_SCALING:
       return "NVTE_MXFP8_1D_SCALING";
-    case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
-      return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING";
+    case NVTE_NVFP4_1D_SCALING:
+      return "NVTE_NVFP4_1D_SCALING";
     case NVTE_INVALID_SCALING:
       return "NVTE_INVALID_SCALING";
   }
@@ -94,12 +95,11 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
                  t.columnwise_scale_inv.shape, ")");
     }
   } else {
-    if (t.scaling_mode == NVTE_MXFP8_1D_SCALING ||
-        t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) {
+    if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
       // Need (4, 128) alignment even for e8 scaling factor
       auto block_alignment = std::vector<size_t>{128ul, 4ul};
       size_t expected_x, expected_y, alignment;
-      const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16;
+      const size_t block_size_rowwise = 32;
       const size_t block_size_colwise = 32;

       if (t.has_data()) {
@@ -110,6 +110,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
         expected_y =
             DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(block_size_rowwise)), alignment) *
             alignment;
         const auto &expected = std::vector<size_t>{expected_x, expected_y};
         NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
                    "\" has invalid scale_inv shape (expected ", expected, ", got ",
@@ -122,11 +123,29 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
             alignment;
         alignment = block_alignment[0];
         expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(1)), alignment) * alignment;
         const auto &expected = std::vector<size_t>{expected_x, expected_y};
         NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
                    "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
                    t.columnwise_scale_inv.shape, ")");
       }
+    } else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) {
+      if (t.has_data()) {
+        const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_first_dim(), 128);
+        const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_last_dim(), 16lu), 4);
+        const auto &expected = std::vector<size_t>{expected_y, expected_x};
+        NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
+                   "\" has invalid scale_inv shape (expected ", expected, ", got ",
+                   t.scale_inv.shape, ")");
+      }
+      if (t.has_columnwise_data()) {
+        const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_last_dim(), 128);
+        const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_first_dim(), 16lu), 4);
+        const auto &expected = std::vector<size_t>{expected_y, expected_x};
+        NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
+                   "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
+                   t.columnwise_scale_inv.shape, ")");
+      }
     }
   }
 }
@@ -154,6 +173,26 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
                  "(expected Float32 or Byte, got ",
                  to_string(t.columnwise_scale_inv.dtype), ")");
     }
+  } else if (is_fp4_dtype(type)) {
+    // TODO(ksivaman): Fix this to check for amaxes and other details.
+    // For now only needed for swizzle.
+    if (t.has_data()) {
+      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
+                 "_scale_inverse must be allocated");
+      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name,
+                 "_scale_inverse has invalid dtype "
+                 "(expected DType::kFloat8E4M3, got ",
+                 to_string(t.scale_inv.dtype), ")");
+    }
+    if (t.has_columnwise_data()) {
+      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
+                 "_columnwise_scale_inverse must be allocated");
+      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP8 scaling factor input ",
+                 name,
+                 "_columnwise_scale_inverse has invalid dtype "
+                 "(expected DType::kFloat8E4M3, got ",
+                 to_string(t.columnwise_scale_inv.dtype), ")");
+    }
   } else {
     NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name);
     NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name);
@@ -195,10 +234,29 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
                  "(expected Float32 or Float8E8M0, got ",
                  to_string(t.columnwise_scale_inv.dtype), ")");
     }
+  } else if (is_fp4_dtype(type)) {
+    // FP4 output needs to have the scale_inv
+    if (t.has_data()) {
+      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
+                 "_scale_inverse must be allocated");
+      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name,
+                 "_scale_inverse has invalid dtype "
+                 "(expected Float8E4M3, got ",
+                 to_string(t.scale_inv.dtype), ")");
+    }
+    if (t.has_columnwise_data()) {
+      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
+                 "_columnwise_scale_inverse must be allocated");
+      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ",
+                 name,
+                 "_columnwise_scale_inverse has invalid dtype "
+                 "(expected Float8E4M3, got ",
+                 to_string(t.columnwise_scale_inv.dtype), ")");
+    }
   } else {
     NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
-    // Note: amax is supported for non-FP8 output as it can be fused into the computation
-    // and later used for quantization with no need to compute it separately
+    // Unfused quant with level 2 nvfp4 scaling will produce high precision tensors with amax.
+    // NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name);
     NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
     NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
                "Scale_inv is not supported for non-FP8 input ", name);
@@ -491,6 +549,9 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
     case kNVTEColumnwiseScaleInv:
       t->columnwise_scale_inv = *param;
       break;
+    case kNVTEColumnwiseAmax:
+      t->columnwise_amax = *param;
+      break;
     default:
       NVTE_ERROR("Unknown tensor parameter!");
   }
@@ -514,6 +575,8 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
       return t.scale_inv;
     case kNVTEColumnwiseScaleInv:
       return t.columnwise_scale_inv;
+    case kNVTEColumnwiseAmax:
+      return t.columnwise_amax;
     default:
       NVTE_ERROR("Unknown tensor parameter!");
   }
@@ -629,6 +692,15 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
     case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
       std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size);
       break;
+    case kNVTEQuantizationConfigRNGState:
+      std::memcpy(&config_.rng_state, buf, attr_size);
+      break;
+    case kNVTEQuantizationConfigNVFP42DQuantization:
+      std::memcpy(&config_.nvfp4_2d_quantization, buf, attr_size);
+      break;
+    case kNVTEQuantizationConfigStochasticRounding:
+      std::memcpy(&config_.stochastic_rounding, buf, attr_size);
+      break;
     default:
       NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
   }
...
@@ -8,6 +8,7 @@
 #define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_

 #include "../common.h"
+#include "transformer_engine/transformer_engine.h"

 namespace transformer_engine::detail {
@@ -62,6 +63,14 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor
                                          const bool pow_2_scale, const SimpleTensor &noop_tensor,
                                          cudaStream_t stream);

+void quantize_transpose_vector_blockwise_fp4(
+    const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv,
+    SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon,
+    const bool return_identity, const bool return_transpose, const bool pow2_scale,
+    const bool swizzled_scale, const bool use_stochastic_rounding,
+    const NVTETensor rng_state_tensor, const bool use_2d_quantization,
+    const SimpleTensor &noop_tensor, cudaStream_t stream);
+
 }  // namespace transformer_engine::detail

 #endif  // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cfloat>
#include <cuda/barrier>
#include <utility>
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
namespace transformer_engine {
#if CUDA_VERSION >= 12080
namespace quantize_transpose_nvfp4 {
namespace {
using std::int32_t;
using std::uint32_t;
using std::uint8_t;
using transformer_engine::detail::TypeExtrema;
// Define a cuRANDDx descriptor
// Note: curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it defaults to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used to do some internal checks, e.g.,
// whether shared memory, if needed, is sufficient for the described problem; usually not applicable.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
curanddx::SM<800>() + curanddx::Thread());
// clang-format off
/*
Step 1: Load input to shared memory
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 8 times
* What each thread does in each loop:
* 8 elements are read from the input at a time
* 2 elements are written to the shared memory at a time, for a total of 4 times
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 1 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 7 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 8 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 2: Cast and store to output_c
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 4 times
* What each thread does in each loop:
* 2 elements are read from the shared memory at a time, for a total of 8 times
* Every 8 consecutive threads do reduction and calculate the amax of each row
* 16 elements are quantized and write to output_c at a time
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 |
| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 1 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 7 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 4 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 3: Transpose, cast and store to output_t
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 2 times
* What each thread does in each loop:
* 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
* Every 8 consecutive threads do reduction and calculate the amax of each column
* 16 elements are quantized and write to output_c at a time, for a total of 2 times
+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | |
| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | |
| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | |
| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 |
| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | |
| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | |
| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | |
| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
*/
// clang-format on
constexpr int kThreadsPerWarp = 32;
// for fp4, we use uint8_t to store 2 fp4 numbers
constexpr int kNFP4PerContainer = 2;
// Hyperparameters for performance tuning
constexpr int kTileDim = 128;
// constexpr int kScaleDim = 32;
constexpr int kNVecIn = 8; // The number of elements each LDG touches
constexpr int kNVecOut = 16; // The number of elements each STG touches
constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches
constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total
// Auto-calculated constants, do not modify directly
static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem");
static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem");
constexpr int kSMemRow = kTileDim;
constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1;
constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem;
constexpr int kNumThreadsLoad = kTileDim / kNVecIn; // 16
constexpr int kNumThreadsStore = kTileDim / kNVecOut; // 8
// constexpr int kNumThreadsReduce = kScaleDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
// for 2D block scaling, we need to reduce amax in warp
static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = {
0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080};
// max for every group_size elements in warp
template <int group_size, int shfl_down_stride>
__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) {
for (int offset = group_size / 2; offset > 0; offset /= 2) {
val = max(val, __shfl_down_sync(groupMask, val, offset * shfl_down_stride));
}
return val;
}
template <typename ScaleType>
__device__ __forceinline__ ScaleType ComputeDecodeScaleFP4(const float amax,
const float global_encode_scale) {
float decode_scale = amax / TypeExtrema<fp4e2m1>::max;
decode_scale = decode_scale * global_encode_scale;
decode_scale = fminf(decode_scale, TypeExtrema<float>::max);
return static_cast<ScaleType>(decode_scale);
}
template <typename ScaleType>
__device__ __forceinline__ float ComputeEncodeScaleFP4(ScaleType decode_scale,
const float global_decode_scale) {
return fminf(1.0f / (static_cast<float>(decode_scale) * global_decode_scale),
TypeExtrema<float>::max);
}
template <typename IType, typename ScaleType>
__device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scale) {
return static_cast<float>(input) * encode_scale;
}
__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;
float global_encode_scale = fp8_max * fp4_max / global_amax;
// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.f || global_encode_scale == 0.f) {
return 1.f;
}
return global_encode_scale;
}
__device__ __forceinline__ uint32_t get_rbits(RNG& rng, uint4& random_uint4, int& rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
curanddx::uniform_bits dist;
random_uint4 = dist.generate4(rng);
}
// Treat uint4 as an array of 4x uint32_t elements for indexing
const uint32_t* const rbits_arr = reinterpret_cast<uint32_t*>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
return rbits;
}
template <class ScaleType>
__device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, size_t col_idx,
uint32_t col_length) {
// This function takes indices into the scale factor matrix and returns an offset in the
// swizzled format. row_idx and col_idx are the original (unswizzled) indices into the scale
// factor matrix, and col_length is the number of columns of the scale factor matrix.
// https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts
// For any scale factor matrix, it's 512B base block. Each base block consists of 128 rows and 4
// columns. Base block is divided into 4 column blocks, each column block has 32 rows and 4
// columns.
// NOTE: There are not a lot of good illustrations about the swizzled scale factor matrix.
// To think in high level, the swizzled scale factor matrix could be composed as:
// unswizzled_scale_factor_matrix = torch.empty((M, N // 16), dtype=torch.uint8)
// cbg_cnt = N // 16 // 4 # Assuming N is divisible by 64
// rb_cnt = M // 128 # Assuming M is divisible by 128
// tmp = unswizzled_scale_factor_matrix.reshape(rb_cnt, 4, 32, cbg_cnt, 4)
// tmp = torch.permute(tmp, (0, 3, 2, 1, 4))
// swizzled_scale_factor_matrix = tmp.reshape((-1, 128, 4))
constexpr uint32_t kTotalRowsPerBaseBlock = 128;
constexpr uint32_t kRowsPerBaseBlockCol = 32;
constexpr uint32_t kColsPerBaseBlockCol = 4;
const size_t rb = row_idx / kTotalRowsPerBaseBlock;
const size_t rem = row_idx % kTotalRowsPerBaseBlock;
const size_t d4 = rem / kRowsPerBaseBlockCol;
const size_t d3 = rem % kRowsPerBaseBlockCol;
const size_t cbg = col_idx / kColsPerBaseBlockCol;
const size_t d5 = col_idx % kColsPerBaseBlockCol;
const size_t cbg_cnt = DIVUP(col_length, kColsPerBaseBlockCol);
// row-major offset in the logical shape
// (rb_cnt , cbg_cnt , 32 , 4 , 4)
// The magic number 16 below is the row-major stride of d3 in the logical shape above:
// 16 = 4 (number of d4 values = kTotalRowsPerBaseBlock / kRowsPerBaseBlockCol) * kColsPerBaseBlockCol.
return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5;
}
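// Worked example (assumed values): row_idx = 130, col_idx = 5, col_length = 8 gives rb = 1,
// rem = 2, d4 = 0, d3 = 2, cbg = 1, d5 = 1, cbg_cnt = 2, so the swizzled offset is
// ((1 * 2 + 1) * 32 + 2) * 16 + 0 * 4 + 1 = 1569.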
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
uint16_t out_4x;
asm volatile(
"{\n"
"cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t"
"}"
: "=h"(out_4x)
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits));
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x);
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
uint16_t dummy = 0;
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
const float2 in23,
const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
// NOTE: rbits unused for rn.
uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing.
asm volatile(
"{\n"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x));
return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0];
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
uint16_t dummy = 0;
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
template <bool kApplyStochasticRounding>
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
const uint32_t rbits) {
if constexpr (kApplyStochasticRounding) {
return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits);
} else {
return cvt_fp32_to_fp4_4x_with_rn(in01, in23, rbits);
}
}
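// Illustrative example (assumed inputs): quantizing already-scaled values {1.7f, -0.6f, 3.2f, 5.0f}
// with round-to-nearest yields the E2M1 values {1.5, -0.5, 3.0, 4.0} (5.0 is halfway between 4 and 6
// and rounds to the even mantissa). With stochastic rounding the direction is chosen randomly,
// biased by how close the input is to each representable neighbor, with rbits supplying the random bits.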
template <bool kReturnIdentity, bool kReturnTranspose, bool kIsE8Scaling, bool kAligned,
typename CType, typename IType, typename OType, typename ScaleType, bool kSwizzledScale,
bool kApplyStochasticRounding, bool kIs2DBlockScaling>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
const IType* const input, const float* global_amax, OType* const output_c,
OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t,
const size_t row_length, const size_t num_rows, const size_t scale_stride_x,
const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y,
const size_t kScaleBlockDim, const float epsilon, const size_t* rng_state,
const float* noop_ptr) {
constexpr int kNVecContainer = kNVecOut / kNFP4PerContainer;
using SMemVec = Vec<IType, kNVecSMem>;
using OVec = Vec<OType, kNVecContainer>;
union IVec {
Vec<IType, kNVecIn> input_type;
Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
};
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
const size_t block_idx_x = blockIdx.x;
const size_t block_idx_y = blockIdx.y;
const size_t rng_sequence =
threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0};
// Index into random_uint4. It increments on each use and resets to 0 (drawing fresh random bits) when it reaches 4.
int rnd_idx = 0;
extern __shared__ char smem_base[];
SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);
// 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode.
// Instead of static_assert, return early if these invalid modes are detected.
if constexpr (kIs2DBlockScaling && kIsE8Scaling) {
return;
}
if constexpr (kIs2DBlockScaling && !kReturnIdentity) {
return;
}
// For a 128x128 tile, 2D block scaling produces an 8x8 grid of amax values for NVFP4 (4x4 for 2D MXFP4).
// Sizes are defined as constexpr; when 2D scaling is not used, the minimal 1x1 size is allocated.
constexpr int kFP4BlockScalingSize = 16;
constexpr int k2DBlockAmaxDim = kIs2DBlockScaling ? (kTileDim / kFP4BlockScalingSize) : 1;
constexpr int kNumRowsPerWarp = kThreadsPerWarp / kNumThreadsStore; // 4
constexpr int k2DBlockAmaxReduceDim =
kIs2DBlockScaling ? (kFP4BlockScalingSize / kNumRowsPerWarp) : 1;
__shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim];
__shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim];
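// For NVFP4 2D scaling, k2DBlockAmaxReduceDim = 4 because each 16-row scaling block spans four
// warp-sized row groups (kNumRowsPerWarp = 4): four partial amax values per block are staged in
// amax_smem_red and then combined into amax_smem.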
// Step 1: Load input to shared memory
{
constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory
const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = (c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g)
: 0); // For not aligned case
const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
IVec input_vec;
// Step 1.1: Load from global memory (input) to registers
if constexpr (kAligned) {
input_vec.input_type.load_from(input_g);
} else {
if (r_g < num_rows) {
input_vec.input_type.load_from_elts(input_g, 0, num_ele);
} else {
input_vec.input_type.clear();
}
}
// Step 1.2: Write to shared memory
#pragma unroll
for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
}
// Step 1.3: Update input address, row index of shared memory, (and row index of global memory
// for not aligned case)
input_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
__syncthreads();
const int kNumThreadsReduce = kScaleBlockDim / kNVecOut;
const float global_encode_scale =
kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]);
const float global_decode_scale = 1.0 / global_encode_scale;
// Step 2: Cast and store to output_c
if constexpr (kReturnIdentity) {
constexpr int r_stride =
kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory
const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele =
(c_g < row_length ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
(row_length - c_g) / kNFP4PerContainer)
: 0); // For not aligned case
OType* output_g =
&output_c[(r_g * row_length + c_g) / kNFP4PerContainer]; // Output address in global memory
// Every kNumThreadsStore threads process one row. The amax reduction is done across groups of
// kNumThreadsReduce threads, so we need the lane ID of the first thread in each group.
const unsigned src_lane =
(threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut / kNVecSMem];
// Step 2.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
// Step 2.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
}
}
// Step 2.3: Reduce amax
if constexpr (kIsE8Scaling) {
#pragma unroll
for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax = __shfl_sync(mask, amax, src_lane);
}
// doing shuffle sync for 2D block scaling (not applicable for E8 scaling)
if constexpr (kIs2DBlockScaling) {
// first amax shuffle sync in warp, then reduce in smem
// T0 T8 T16 T24 should do amax reduction together
constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32
int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7
int tid_in_warp_x = threadIdx.x % kNumThreadsStore;
int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp;
CType amax_warp_reduced = groupMax<kNumRowsPerWarp, kNumThreadsStore>(
amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]);
// now T0 ~ T7 in each warp hold the reduced amax values
int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y;
if (tid_in_warp_y == 0) {
amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]
[warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced;
}
__syncthreads();
if (data_row_idx % kFP4BlockScalingSize == 0) {
CType amax_2d = 0.0;
for (int i = 0; i < k2DBlockAmaxReduceDim; i++) {
amax_2d = fmaxf(amax_2d,
amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]);
}
amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d;
}
__syncthreads();
// every thread now knows 2D amax
amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x];
}
// Step 2.4: Compute scale
ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
// Step 2.5: Write scale_inv
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g < num_rows);
write_scale_inv &= (c_g < row_length);
}
if (write_scale_inv) {
size_t row_idx = block_idx_y * kTileDim + r_s;
size_t col_idx = block_idx_x * (kNumThreadsStore / kNumThreadsReduce) +
(threadIdx.x % kNumThreadsStore) / kNumThreadsReduce;
if constexpr (kSwizzledScale) {
size_t offset = scale_factor_swizzled_offset<ScaleType>(
row_idx, col_idx, DIVUP(row_length, kScaleBlockDim));
tile_scales_inv_c[offset] = scale_inv;
} else {
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
}
}
// Step 2.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) {
// Pack four consecutive elements into two float2 values
float2 f2_a;
float2 f2_b;
f2_a.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[0], encode_scale);
f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[1], encode_scale);
f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[0], encode_scale);
f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[1], encode_scale);
const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
// Convert to __nv_fp4x4_e2m1
__nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0];
output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1];
}
// Step 2.7: Store output_c
if constexpr (kAligned) {
output_vec.store_to(output_g);
} else {
if (r_g < num_rows) {
output_vec.store_to_elts(output_g, 0, num_ele);
}
}
// Step 2.8: Update output address, row index of shared memory (and row index of global memory
// for not aligned case)
output_g += stride_g / kNFP4PerContainer;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
// Step 3: Transpose, cast and store to output_t
if constexpr (kReturnTranspose) {
constexpr int c_stride =
kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory
constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory
int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory
size_t r_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Row in global memory
const size_t c_g = block_idx_y * kTileDim + r_s; // Column in global memory
const size_t stride_g =
static_cast<size_t>(c_stride) * kNVecSMem * num_rows; // Stride in global memory
const size_t num_ele = (c_g < num_rows ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
(num_rows - c_g) / kNFP4PerContainer)
: 0); // For not aligned case
OType* output_g =
&output_t[(r_g * num_rows + c_g) / kNFP4PerContainer]; // Output address in global memory
// Every kNumThreadsStore threads process one row. The amax reduction is done across groups of
// kNumThreadsReduce threads, so we need the lane ID of the first thread in each group.
const unsigned src_lane =
(threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut];
// Step 3.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
int r = r_s + i;
int c = c_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
#pragma unroll
for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
// Step 3.2: Compute local amax
CType amax = 0;
if constexpr (kIs2DBlockScaling) {
// TODO(zhongbo): 2D block scaling, directly read from amax_smem
int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7
constexpr int kNumColsPerWarp =
kThreadsPerWarp / kNumThreadsStore * kNVecSMem; // 8 elements
constexpr int kNumWarpsPerBlock =
kThreadsPerBlock / kThreadsPerWarp; // 8 warps per block
constexpr int kNumColsPerIter = kNumColsPerWarp * kNumWarpsPerBlock;
int tid_in_warp_x = (threadIdx.x / kNumThreadsStore) % kNumColsPerWarp;
int tid_in_warp_y = (threadIdx.x % kThreadsPerWarp) % kNumThreadsStore;
int data_col_idx = iter * kNumColsPerIter + warp_idx * kNumColsPerWarp + tid_in_warp_x;
amax = amax_smem[tid_in_warp_y][data_col_idx / kFP4BlockScalingSize];
} else {
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
}
}
// Step 3.3: Reduce amax
if constexpr (kIsE8Scaling) {
#pragma unroll
for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax = __shfl_sync(mask, amax, src_lane);
}
// Step 3.4: Compute scale
ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
// Step 3.5: Write scale_inv_t
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g + smem_idx < row_length);
write_scale_inv &= (c_g < num_rows);
}
if (write_scale_inv) {
size_t row_idx = block_idx_x * kTileDim + c_s * kNVecSMem + smem_idx;
size_t col_idx = (block_idx_y * (kNumThreadsStore / kNumThreadsReduce) +
(threadIdx.x % kNumThreadsStore) / kNumThreadsReduce);
if constexpr (kSwizzledScale) {
size_t offset = scale_factor_swizzled_offset<ScaleType>(
row_idx, col_idx, DIVUP(num_rows, kScaleBlockDim));
tile_scales_inv_t[offset] = scale_inv;
} else {
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 3.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) {
// Pack four elements (one per source row) into two float2 values
float2 f2_a;
float2 f2_b;
f2_a.x =
ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i].data.elt[smem_idx], encode_scale);
f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i + 1].data.elt[smem_idx],
encode_scale);
f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1)].data.elt[smem_idx],
encode_scale);
f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx],
encode_scale);
const uint32_t rbits =
kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
// Convert to __nv_fp4x4_e2m1
__nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0];
output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1];
}
// Step 3.7: Store output_t
if constexpr (kAligned) {
output_vec.store_to(output_g + smem_idx * num_rows / kNFP4PerContainer);
} else {
if (r_g + smem_idx < row_length) {
output_vec.store_to_elts(output_g + smem_idx * num_rows / kNFP4PerContainer, 0,
num_ele);
}
}
}
// Step 3.8: Update output address, column index of shared memory (and row index of global
// memory for not aligned case)
output_g += stride_g / kNFP4PerContainer;
c_s += c_stride;
if constexpr (!kAligned) {
r_g += c_stride * kNVecSMem;
}
}
}
}
} // namespace
} // namespace quantize_transpose_nvfp4
#endif // CUDA_VERSION >= 12080
namespace detail {
void quantize_transpose_vector_blockwise_fp4(
const SimpleTensor& input, const SimpleTensor& global_amax, SimpleTensor& scale_inv,
SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon,
const bool return_identity, const bool return_transpose, const bool pow2_scale,
const bool swizzled_scale, const bool use_stochastic_rounding,
const NVTETensor rng_state_tensor, const bool use_2d_quantization,
const SimpleTensor& noop_tensor, cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4);
#if CUDA_VERSION >= 12080
// pow2_scale corresponds to MXFP4, which uses E8M0 scaling; raise an error if it is set.
NVTE_CHECK(!pow2_scale, "No support for pow2_scale (MXFP4 E8M0 scaling) for now");
if (!return_identity && !return_transpose) {
return;
}
if (use_2d_quantization && !return_identity) {
return;
}
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
size_t num_rows = 1;
for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
num_rows *= input.shape.at(i);
num_elements *= input.shape.at(i);
}
// Early return if the input tensor is empty
if (num_elements == 0) {
return;
}
size_t scale_stride_x = 0;
size_t scale_stride_y = 0;
if (return_identity) {
scale_stride_x = 1;
scale_stride_y = scale_inv.shape[1];
}
size_t scale_t_stride_x = 0;
size_t scale_t_stride_y = 0;
if (return_transpose) {
scale_t_stride_x = 1;
scale_t_stride_y = scale_inv_t.shape[1];
}
using namespace transformer_engine::quantize_transpose_nvfp4;
const size_t num_blocks_x = DIVUP(row_length, static_cast<size_t>(kTileDim));
const size_t num_blocks_y = DIVUP(num_rows, static_cast<size_t>(kTileDim));
// noop tensor for cuda graph
const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
const size_t* rng_state = nullptr;
if (rng_state_tensor != nullptr) {
Tensor& rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
rng_state = reinterpret_cast<const size_t*>(rng_state_te_tensor.data.dptr);
}
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(
output.dtype, 2, OutputType,
dim3 grid(num_blocks_x, num_blocks_y, 1);
using ScaleType = fp8e4m3; constexpr int kScaleBlockDim = 16;
constexpr bool kPow2Scale = false;
const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_identity, kReturnIdentity,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transpose, kReturnTranspose,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
full_tile, kAligned,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
swizzled_scale, kSwizzledScale,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kApplyStochasticRounding,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_2d_quantization, kIs2DBlockScaling,
size_t smem_bytes = kSMemSize * sizeof(InputType);
auto kernel = block_scaled_1d_cast_transpose_kernel<
kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned,
float, InputType, OutputType, ScaleType, kSwizzledScale,
kApplyStochasticRounding, kIs2DBlockScaling>;
if (smem_bytes >= 48 * 1024) {
cudaError_t err = cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_bytes);
NVTE_CHECK(err == cudaSuccess,
"Failed to set dynamic shared memory size.");
} kernel<<<grid, kThreadsPerBlock, smem_bytes,
stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<const float*>(global_amax.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<ScaleType*>(scale_inv.dptr),
reinterpret_cast<ScaleType*>(scale_inv_t.dptr), row_length,
num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x,
scale_t_stride_y, kScaleBlockDim, epsilon, rng_state,
noop_ptr);) // kIs2DBlockScaling
) // kApplyStochasticRounding
) // kSwizzledScale
) // kAligned
) // kReturnTranspose
) // kReturnIdentity
) // OutputType
) // InputType
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // CUDA_VERSION >= 12080
}
} // namespace detail
} // namespace transformer_engine
@@ -598,6 +598,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
// const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
@@ -828,6 +829,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
@@ -947,6 +949,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
const size_t in_gate_mem = buff_size_aligned_in;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out;
const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
(out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
@@ -1260,7 +1263,7 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
cast_gated<ParamOP, ActOP>(gated_input, output, stream);
}
}
} else if (is_mxfp8_scaling(output->scaling_mode)) {
if (use_tma_kernels) {
cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
} else {
...
@@ -23,6 +23,7 @@
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "math.h"
#include "nvfp4_transpose.cuh"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
@@ -108,6 +109,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
// helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
@@ -135,8 +138,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
OType *out_rowwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem);
OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
@@ -284,7 +288,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const float scaled_out = in * block_scale_inverse;
const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
}
}
@@ -408,10 +412,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
if (rowwise_scale_is_within_bounds) {
scales_rowwise[scale_idx] = biased_exponent;
}
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
@@ -439,7 +445,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
}
}
@@ -454,19 +460,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset]));
}
if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
@@ -487,18 +493,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Added extra 1-element padding per thread_X to reduce bank conflicts
float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
const int shmem_thread_offset =
tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
const int shmem_elt_idx = swizzled_group_offset + e;
partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
}
}
@@ -506,15 +512,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
for (int i = 0; i < THREADS_Y; ++i) {
// Add extra element offset per MXFP8 scaling block [1x32]
const int scaling_block = threadIdx.x / SCALE_DIM_X;
thread_partial_dbias +=
partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
}
}
const int dbias_stride = cols;
const int dbias_offset_Y = blockIdx.y;
const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
if (!col_out_of_bounds_dbias) {
dbias_workspace[dbias_idx] = thread_partial_dbias;
@@ -536,6 +542,528 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
}
} // namespace mxfp8_kernel
namespace nvfp4_kernel {
using namespace ptx;
constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 16;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = 32;
constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 16 = 256 / 16
// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax,
const float S_enc) {
constexpr float rcp_6f = 1.0f / 6.0f;
// const float S_dec_b = block_amax * rcp_6f;
// const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
// return S_dec_b_fp8;
return static_cast<fp8e4m3>(block_amax * rcp_6f * S_enc);
}
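// Illustrative example (assumed values): with block_amax = 3.0f and S_enc = 1.0f the per-block
// decode factor is 3.0f * (1/6) = 0.5f, stored exactly in E4M3; the kernel then uses
// block_scale_inverse = S_enc / 0.5f = 2.0f to map block elements into the E2M1 range.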
#define DIRECT_SCALING_FACTORS_STORE 1
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, typename OType, bool COLWISE_SCALING, size_t CHUNK_DIM_Y,
size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output_rowwise,
const __grid_constant__ CUtensorMap tensor_map_output_colwise,
fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0,
const float *noop, float *const amax_ptr,
const float *const nvfp4_second_stage_scale_ptr, const size_t rows,
const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool ROWWISE_SCALING = true;
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE;
static_assert(BUFF_DIM_Y >= SCALE_DIM_Y,
"Number of buffer rows must be greater than or equal to the size of the columnwise "
"scaling block");
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE,
"Number of buffer rows must be greater than or equal to the number of rowwise "
"processing threads in Y dimension");
constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X;
constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Each output byte holds two 4-bit elements
constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X;
constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE;
// static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of
// // threads to process one row in a single iteration
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;
const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const int tid_Y_colwise = 0;
const int tid_X_colwise = threadIdx.x;
const int thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise;
const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements
const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const int col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols;
// helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_nvfp4 =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_mxfp8 =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_nvfp4_scales =
CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3);
constexpr size_t buff_size_mxfp8_scales =
(CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0);
constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0);
constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0);
constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0);
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise_data);
fp8e4m3 *out_rowwise_scales_sh =
reinterpret_cast<fp8e4m3 *>(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
e8m0_t *out_colwise_scales_sh = reinterpret_cast<e8m0_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// Compute a global encoding/decoding scaling factor for all S_dec_b
const float S_enc =
(nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr);
float thread_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (int stage = 0; stage < STAGES; ++stage) {
const int buff = stage % BUFFS_NUM;
const int next_stage = stage + 1;
const int stage_offset_Y = stage * BUFF_DIM_Y;
const int buff_offset_in = buff * BUFF_IN_DIM;
const int buff_offset_out = buff * BUFF_OUT_DIM;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const int next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_IN_DIM;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise;
block_amax = 0.0f;
float in_compute_colwise[SCALE_DIM_Y];
IType in_colwise_IType[SCALE_DIM_Y];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
}
block_amax = static_cast<float>(block_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
// Without activations, out-of-bounds elements are zero, so including them in the amax is safe
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(block_amax * Quantized_Limits<OType>::max_norm_rcp);
const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
const int global_scales_offset_X = scales_offset_X_colwise;
const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
if (colwise_scale_is_within_bounds) {
scales_colwise_e8m0[scale_idx] = biased_exponent;
}
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
// 3. Scale elements
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
float in;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
in = static_cast<float>(in_colwise_IType[i]);
} else {
in = in_compute_colwise[i];
}
const float scaled_out = in * block_scale_inverse;
const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
}
}
if constexpr (ROWWISE_SCALING) {
const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (int it = 0; it < ITERATIONS_ROWWISE; ++it) {
const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const int shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const int shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;
block_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM_X];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
// Since TMA requires 16B data alignment (i.e. cols % 8 == 0 for BF16 elements), a single
// bounds check in the column direction is sufficient to ensure the entire wave is within bounds
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds =
(row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
// Without activations, out-of-bounds elements are zero, so including them in the amax is safe
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc);
#if DIRECT_SCALING_FACTORS_STORE
// Check boundaries
if (rowwise_scale_is_within_bounds) {
const int scales_offset_Y =
scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE;
const int scales_offset_X = scales_offset_X_rowwise;
const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X;
scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8;
}
#else
const int shmem_scales_offset_Y =
stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise;
const int shmem_scales_offset_X = tid_X_rowwise;
const int scale_idx =
shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X;
out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8;
#endif
// Compute "correct" per-block encoding scaling factor
const float block_scale_inverse =
__fdiv_rn(S_enc, static_cast<float>(S_dec_b_fp8)); // S_enc_b_fp8
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
in01 = in_IType[w].data.elt[2 * e];
in23 = in_IType[w].data.elt[2 * e + 1];
} else if constexpr (IS_CACHED_ACT_OP) {
in01.x = in_cached[w].data.elt[4 * e];
in01.y = in_cached[w].data.elt[4 * e + 1];
in23.x = in_cached[w].data.elt[4 * e + 2];
in23.y = in_cached[w].data.elt[4 * e + 3];
} else {
const int j = w * PACK_SIZE + 4 * e;
in01.x = in_compute_rowwise[j];
in01.y = in_compute_rowwise[j + 1];
in23.x = in_compute_rowwise[j + 2];
in23.y = in_compute_rowwise[j + 3];
}
fp4e2m1x4 &out_quad = reinterpret_cast<fp4e2m1x4 &>(out.data.elt[e]);
ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse);
}
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
__builtin_assume(block_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM;
const int buff_offset_mxfp8 = buff * BUFF_IN_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset_nvfp4]));
}
if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset_mxfp8]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
}
#if !DIRECT_SCALING_FACTORS_STORE
// Vectorized store of scaling factors.
// Each thread stores multiple scaling factors in one store instruction.
if constexpr (ROWWISE_SCALING) {
// Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise;
const int scale_idx_global =
scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise;
const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) &&
(scales_offset_X_rowwise < (cols / SCALE_DIM_X))) {
using ScalesVec_t = Vec<fp8e4m3, NVFP4_SCALING_FACTORS_PER_CHUNK_ROW>;
const ScalesVec_t &scales =
*reinterpret_cast<ScalesVec_t *>(&out_rowwise_scales_sh[scale_idx_shmem]);
scales.store_to(&scales_rowwise_e4m3[scale_idx_global]);
}
}
#endif
float chunk_amax = 0.0f;
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
chunk_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(thread_amax, warp_id);
}
if (is_master_thread && amax_ptr != nullptr) {
atomicMaxFloat(amax_ptr, chunk_amax);
}
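  // Illustrative note: thread_amax is first reduced across the chunk's warps (reduce_max above),
  // and the per-chunk result is then merged into the tensor-wide amax with a single atomicMaxFloat.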
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace nvfp4_kernel
constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128;
constexpr size_t FP8_THREADS_PER_CHUNK = 128;
@@ -898,7 +1426,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows,
}
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
  const size_t N = product(input.data.shape);
  const bool isFullTile = (N % ELEMS_PER_BLOCK == 0);
@@ -1179,6 +1707,141 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
  ); // NOLINT(*)
}
// This kernel supports only two scaling cases:
// 1. r16c0 - Rowwise NVFP4
// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &)>
void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) {
using namespace nvfp4_kernel;
using namespace ptx;
checkCuDriverContext(stream);
NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated.");
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
bool use_colwise_scaling = output->has_columnwise_data();
if (use_colwise_scaling) {
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
"Columnwise scaling tensor must be allocated");
}
CheckNoopTensor(*noop, "cast_noop");
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 128;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_PER_CHUNK;
const size_t scale_stride_rowwise = output->scale_inv.shape[1];
const size_t scale_stride_colwise =
use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1;
fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast<fp8e4m3 *>(output->scale_inv.dptr);
e8m0_t *const scales_colwise_e8m0_ptr =
use_colwise_scaling ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr;
const ScalingType scaling_type =
use_colwise_scaling ? ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const float *const nvfp4_second_stage_scale_ptr =
reinterpret_cast<const float *>(output->scale.dptr);
  // The output data type is only needed for the column-wise MXFP8 scaling.
  // It has no effect on the row-wise NVFP4 scaling, but defaults to E4M3 so the type-switch macros work.
const DType output_data_type =
use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3;
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, nvfp4_kernel::BUFF_DIM_Y,
BUFF_DIM_X, cols, 0, sizeof(IType) * 8);
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols,
nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4);
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8);
}
constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_nvfp4 =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_mxfp8 =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_nvfp4_scales =
(CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3);
constexpr size_t buff_size_mxfp8_scales =
(CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t);
constexpr size_t in_mem = buff_size_aligned_in;
const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4;
const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0;
const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales;
const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0;
const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem +
out_rowwise_scales_mem + out_colwise_scales_mem +
TMA_SHMEM_ALIGNMENT;
const size_t dshmem_size = in_mem + out_mem;
switch (scaling_type) {
case ScalingType::ROWWISE:
cudaFuncSetAttribute(
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false, CHUNK_DIM_Y,
CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute(
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true, CHUNK_DIM_Y,
CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
}); // NOLINT(*)
); // NOLINT(*)
}
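// Illustrative sketch (not exhaustive) of what nvfp4_quantize fills in:
//   - output->data:                 packed FP4 (two e2m1 values per byte), logical shape rows x cols
//   - output->scale_inv:            one fp8e4m3 decoding factor per 16 consecutive row elements,
//                                   with row stride output->scale_inv.shape[1]
//   - output->columnwise_data and
//     output->columnwise_scale_inv: optional MXFP8 copy with one e8m0 factor per 32 column elements
//   - output->amax, output->scale:  per-tensor amax and the FP32 second-stage scale consumed above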
namespace detail {
using Empty = transformer_engine::Empty;
@@ -1386,20 +2049,33 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
  auto dbias_tensor = convertNVTETensor(dbias);
  auto workspace_tensor = convertNVTETensor(workspace);
  // Quantization config
  QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
  // Check for unsupported options
  if (quant_config_cpp.stochastic_rounding) {
    NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
               "Stochastic rounding is only supported for NVFP4 quantization.");
  }
// Dispatch to quantization kernel depending on data format
  switch (output_tensor->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      if (output_tensor->has_columnwise_data()) {
        NVTE_CHECK(output_tensor->has_data(),
                   "Quantizing in only the columnwise direction not supported yet!");
        if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) {
          cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream);
        } else {
          cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, float, ParamOP, OP>(
              *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor,
@@ -1407,51 +2083,90 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
        }
      } else if (output_tensor->has_data()) {
        fp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
            *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
            workspace_tensor, stream);
      }
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
          *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
          workspace_tensor, stream);
      break;
    }
case NVTE_NVFP4_1D_SCALING: {
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*input_tensor, "input");
CheckOutputTensor(*output_tensor, "output", false);
// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
auto dtype = input_tensor->dtype();
bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 &&
output_tensor->has_data();
// Launch NVFP4 quantize kernel
if (use_optimized_kernel) {
if (quant_config_cpp.nvfp4_2d_quantization) {
nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, true>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
} else {
nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, false>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
}
} else {
auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
: output_tensor->columnwise_amax;
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_NVFP4_1D_SCALING for "
"2D quantization");
quantize_transpose_vector_blockwise_fp4(
/*input=*/input_tensor->data, /*global_amax=*/global_amax,
/*scale_inv=*/output_tensor->scale_inv,
/*scale_inv_t=*/output_tensor->columnwise_scale_inv,
/*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
/*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
/*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
/*swizzled_scale=*/false,
/*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
/*rng_state=*/quant_config_cpp.rng_state,
/*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
/*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
}
break;
}
    case NVTE_BLOCK_SCALING_2D: {
      // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
      NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
                 "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      quantize_transpose_square_blockwise(
          input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon,
          /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
          /*noop_tensor=*/noop_tensor->data, stream);
      break;
    }
    case NVTE_BLOCK_SCALING_1D: {
      // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
      NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
                 "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
      FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
      if (output_tensor->has_data()) {
        bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
                                Float8BlockScaleTensorFormat::COMPACT);
        rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
                                         : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
      }
      if (output_tensor->has_columnwise_data()) {
        bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
                                   Float8BlockScaleTensorFormat::COMPACT);
        columnwise_option = columnwise_compact
                                ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
                                : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
@@ -1459,7 +2174,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
      quantize_transpose_vector_blockwise(
          input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
          columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
      break;
    }
    default:
......
@@ -17,6 +17,8 @@
#include <transformer_engine/cast.h>
#include <cfloat>
#include <cstddef>
#include <cstdint>
#include <limits>
#include "../common.h"
@@ -26,6 +28,7 @@
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h" #include "transformer_engine/transpose.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -226,7 +229,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -226,7 +229,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} }
static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
...@@ -247,7 +250,7 @@ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t str ...@@ -247,7 +250,7 @@ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t str
); // NOLINT(*) ); // NOLINT(*)
} }
static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
bool use_rowwise_scaling = input.has_data(); bool use_rowwise_scaling = input.has_data();
bool use_colwise_scaling = input.has_columnwise_data(); bool use_colwise_scaling = input.has_columnwise_data();
checkCuDriverContext(stream); checkCuDriverContext(stream);
...@@ -331,6 +334,81 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s ...@@ -331,6 +334,81 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
); // NOLINT(*) ); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
#if CUDA_VERSION >= 12080
template <typename OType>
__global__ void __launch_bounds__(512)
dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales,
const float *const tensor_amax, const size_t N, const size_t M,
const size_t scale_stride) {
const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t x = thread_idx % M;
const size_t y = thread_idx / M;
union fp4vec {
uint64_t vec;
fp4e2m1x4 small_vec[4];
};
using OVec = Vec<OType, 4>;
const uint64_t *const input_vectorized = reinterpret_cast<const uint64_t *>(input);
OVec *output_vec = reinterpret_cast<OVec *>(output);
const size_t my_index = x + y * M;
const size_t my_scale_index = x + y * scale_stride;
const size_t my_output_index = (x + y * M) * 4;
fp4vec value;
value.vec = input_vectorized[my_index];
fp8e4m3 scale = scales[my_scale_index];
float amax = *tensor_amax;
constexpr float factor_inv = 1.0 / (6.0 * 448.0);
float final_scale = static_cast<float>(scale) * amax * factor_inv;
#pragma unroll
for (int i = 0; i < 4; i++) {
float4 current = static_cast<float4>(value.small_vec[i]);
OVec out;
out.data.elt[0] = static_cast<OType>(current.x * final_scale);
out.data.elt[1] = static_cast<OType>(current.y * final_scale);
out.data.elt[2] = static_cast<OType>(current.z * final_scale);
out.data.elt[3] = static_cast<OType>(current.w * final_scale);
output_vec[my_output_index + i] = out;
}
}
#endif // CUDA_VERSION >= 12080
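// Sketch of the dequantization math used by dequantize_fp4_kernel (for reference): each stored FP4
// value x_q is recovered as
//   x ~= decode_e2m1(x_q) * S_b * amax / (6 * 448)
// where S_b is the per-block E4M3 factor from scale_inv, 6 is the E2M1 max, 448 is the E4M3 max,
// and factor_inv in the kernel precomputes 1 / (6 * 448).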
void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
#if CUDA_VERSION >= 12080
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output");
NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type.");
NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
constexpr int FP4_BLOCK_SIZE = 16;
const size_t N = input.flat_first_dim();
const size_t M = input.flat_last_dim();
NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");
const size_t Mread = M / FP4_BLOCK_SIZE;
const size_t total = N * Mread;
const size_t threads = 512;
const size_t blocks = DIVUP(total, threads);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
dequantize_fp4_kernel<<<blocks, threads, 0, stream>>>(
input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<fp8e4m3 *>(input.scale_inv.dptr),
reinterpret_cast<float *>(input.amax.dptr), N, Mread,
input.scale_inv.shape.back());); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
#endif // CUDA_VERSION >= 12080
}
} // namespace dequantization
namespace detail {
@@ -339,16 +417,24 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream)
  CheckInputTensor(input, "cast_input");
  CheckOutputTensor(*output, "cast_output");
  switch (input.scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      dequantization::fp8_dequantize(input, output, stream);
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      if (is_supported_by_CC_100()) {
        dequantization::mxfp8_dequantize(input, output, stream);
      } else {
        NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
      }
      break;
    }
    case NVTE_NVFP4_1D_SCALING: {
      dequantization::fp4_dequantize(input, output, stream);
      break;
    }
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
  }
}
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file nvfp4_transpose.cuh
* \brief CUDA kernels to cast to NVFP4 and transpose.
*/
#ifndef TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#if CUDA_VERSION > 12080
#include <cuda_fp4.h>
#endif // CUDA_VERSION > 12080
#include <cfloat>
#include "../common.h"
#include "../utils.cuh"
#include "curanddx.hpp"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
#if CUDA_VERSION > 12080
namespace nvfp4_transpose {
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
curanddx::SM<800>() + curanddx::Thread());
using namespace ptx;
using nvfp4_scale_t = fp8e4m3;
constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts)
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_NUM = 128;
constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM;
constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM;
constexpr size_t RNG_GENS_PER_THREAD =
SCALES_PER_THREAD / 4; // Each call generates 4x uint32_t random numbers
constexpr size_t TILE_DIM_Y = 32;
constexpr size_t TILE_DIM_X = 128;
// Should this be SCALE_DIM or BLOCK_DIM? Both are 16, so it works for both 1D and 2D.
constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8
constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y;
constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X;
constexpr size_t STAGES = TILES_Y * TILES_X;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = TILE_DIM_Y;
constexpr size_t BUFF_DIM_X = TILE_DIM_X;
constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM;
// Input buffer (BF16)
constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X;
constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X;
// Output buffer (NVFP4)
constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8;
constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X;
// Output transpose buffer (NVFP4)
constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X;
constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8;
constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X;
// Manual swizzling parameters to reduce SHMEM bank conflicts
constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM / PACK_SIZE;
constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM;
constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8
constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16
constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE;  // 32 / 16 = 2
constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM;
constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE;
static_assert(BUFF_DIM_Y >= SCALE_DIM,
              "Number of buffer rows must be greater than or equal to the size of the columnwise "
              "scaling block");
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE,
              "Number of buffer rows must be greater than or equal to the number of rowwise "
              "processing threads in the Y dimension");
// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16
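// Illustrative note on the swizzling used below: with PACK_SIZE = 8 and SCALE_DIM = 16, each thread
// handles WAVES = 2 groups of 8 elements. The kernels rotate the starting group per bank group,
//   swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM,
// so threads in adjacent bank groups begin on different 8-element halves of the 16-element block and
// their shared memory accesses spread across the 32 banks.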
// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax,
const float S_enc) {
// constexpr float rcp_6f = 1.0f / 6.0f;
// const float S_dec_b = block_amax * rcp_6f;
// const nvfp4_scale_t S_dec_b_fp8 = static_cast<nvfp4_scale_t>(S_dec_b * S_enc);
// return S_dec_b_fp8;
  // NOTE: Dividing by 6.0f is neither elegant nor efficient.
  // However, it matches the emulation code exactly.
using namespace detail;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
const float S_dec_b = block_amax / fp4_max * S_enc;
return static_cast<nvfp4_scale_t>(fminf(S_dec_b, TypeExtrema<float>::max));
}
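// Roughly, the effective block scale seen at dequantization is S_dec_b_fp8 / S_enc ~= block_amax / 6,
// so multiplying the decoded E2M1 values by it restores the block's original dynamic range.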
// Compute the global encode scale factor for a given global amax
__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) {
using namespace detail;
constexpr float fp8_max = TypeExtrema<fp8e4m3>::max; // 448.0f;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
float global_encode_scale = fp8_max * fp4_max / global_amax;
// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
return 1.0f;
}
return global_encode_scale;
}
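// Worked example (illustrative): for global_amax = 8, S_enc = 448 * 6 / 8 = 336. A block with
// amax = 3 then gets S_dec_b ~= (3 / 6) * 336 = 168 (rounded to E4M3), and during quantization its
// elements are multiplied by S_enc / S_dec_b ~= 2, mapping the block amax of 3 onto the FP4 max of 6.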
__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
curanddx::uniform_bits dist;
random_uint4 = dist.generate4(rng);
}
// Treat uint4 as an array of 4x uint32_t elements for indexing
const uint32_t *const rbits_arr = reinterpret_cast<uint32_t *>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
return rbits;
}
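// Usage sketch (illustrative): each stochastic-rounding conversion consumes one uint32 of random
// bits. A single dist.generate4(rng) call fills a uint4 pool of four values; get_rbits() hands them
// out one at a time and regenerates the pool once all four have been consumed.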
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(
const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x,
const float2 scale,
const uint32_t rbits) {
// NOTE: rbits unused for rn.
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}
template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x,
const float2 scale,
const uint32_t rbits) {
if constexpr (USE_STOCHASTIC_ROUNDING) {
return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits);
} else {
return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits);
}
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
const float2 in23,
const float2 scale,
const uint32_t rbits) {
// NOTE: rbits unused for rn.
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}
template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
const float2 scale,
const uint32_t rbits) {
if constexpr (USE_STOCHASTIC_ROUNDING) {
return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits);
} else {
return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits);
}
}
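// Summary of the helpers above (sketch): each takes four input values (packed BF16 or two float2),
// multiplies them by the 2x-replicated FP32 block scale via mul.f32x2, and packs the four e2m1
// results into the low 16 bits of the return value. The *_stochastic_rounding variants additionally
// consume rbits via cvt.rs; the *_rn variants ignore rbits and use round-to-nearest.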
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
nvfp4_transpose_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
const __grid_constant__ CUtensorMap tensor_map_output_t,
nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr,
const float *noop, const float *const amax_rowwise_ptr,
const float *const amax_colwise_ptr, const size_t rows,
const size_t cols, const size_t scale_stride,
const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
const size_t rng_sequence =
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
  // Index of the current random number; incremented on each use and reset (with a pool refill) when it reaches 4.
  int rnd_idx = 0;
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_rows = rows - block_offset_Y;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const size_t tid_X_colwise = threadIdx.x;
const size_t tid_Y_t = tid_X_colwise;
// const size_t tid_X_t = 0;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
const size_t thread_offset_X_colwise = tid_X_colwise;
const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const size_t row_base_colwise = block_offset_Y;
const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
const size_t scales_offset_X_t = scales_block_offset_X_t;
const size_t SFs_per_row = cols / SCALE_DIM;
const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;
  // Helps resolve shared memory bank conflicts
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
constexpr size_t out_mem_rowwise_scales = 0;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// Compute a global encoding/decoding scaling factors for all S_dec_b
const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
? 1.0f
: compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
  // NOTE: This matches how the emulation code was written.
const float S_dec_rowwise = 1.0 / S_enc_rowwise;
const float S_enc_colwise = (amax_colwise_ptr == nullptr)
? S_enc_rowwise
: compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
const float S_dec_colwise = 1.0 / S_enc_colwise;
float thread_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (size_t stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff_offset_in = buff * BUFF_IN_SIZE;
const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
// COLWISE scaling
if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
const size_t in_thread_offset_X = thread_offset_X_colwise;
const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;
const size_t shmem_offset_base_colwise_in =
buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
const size_t shmem_offset_base_colwise_out_t =
buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;
block_amax = 0.0f;
float in_compute_colwise[SCALE_DIM];
IType in_colwise_IType[SCALE_DIM];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
}
block_amax = static_cast<float>(block_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise =
(row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
              // Without an activation, out-of-bounds elements are zero, so including them in the amax is safe.
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_colwise);
// Store scaling factors through SHMEM
const size_t scale_idx_sh =
tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
for (int e = 0; e < SCALE_DIM / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
regs[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(elts, block_scale_inverse_2x,
rbits);
} else {
const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
regs[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const int group = thread_lane / 16;
uint32_t val[2];
uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
          // Helps reduce bank conflicts
switch (group) {
case 0:
val[0] = regs_4x[0];
val[1] = regs_4x[1];
break;
case 1:
val[0] = regs_4x[1];
val[1] = regs_4x[0];
break;
}
uint32_t *out_t_data_sh_as_uint32_t =
reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2;
out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2;
}
}
// ROWWISE scaling
{
const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const size_t shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const size_t shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;
block_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
              // Since the TMA data-alignment requirement is 16B (i.e. cols % 8 == 0 for BF16 elements),
              // a single check in the column direction is enough to guarantee the entire wave is within bounds.
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const size_t j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds =
(row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
                // Without an activation, out-of-bounds elements are zero, so including them in the amax is safe.
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
// Check boundaries
const size_t scales_offset_Y =
scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
const size_t scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;
// const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
const bool rowwise_scale_is_within_bounds_Y =
(stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
scales_ptr[scale_idx_global] = S_dec_b_fp8;
}
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else if constexpr (IS_CACHED_ACT_OP) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const int j = w * PACK_SIZE + 4 * e;
const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
out.data.elt[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t global_offset_Y_t = block_offset_Y_t;
const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
if constexpr (RETURN_TRANSPOSE) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
} // end of stages
// Vectorized store scaling factors through SHMEM
if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
const size_t count = // number of scales in Y dimension of this chunk
(chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
// Fast path: vectorized store when destination is properly aligned
scales_vec.store_to(dst);
} else {
// Safe path: element-wise store for tails or unaligned destinations
scales_vec.store_to_elts(dst, 0, count);
}
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#else
NVTE_DEVICE_ERROR("sm_100 or higher is required.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
nvfp4_transpose_kernel_2D(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
const __grid_constant__ CUtensorMap tensor_map_output_t,
nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr,
const float *noop, const float *const amax_rowwise_ptr,
const float *const amax_colwise_ptr, const size_t rows,
const size_t cols, const size_t scale_stride,
const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
const size_t rng_sequence =
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
  // Index of the random number. It is incremented on each use and resets to 0 after all
  // four 32-bit values of random_uint4 have been consumed.
  int rnd_idx = 0;
  // 2D block-based scaling constants
constexpr size_t BLOCK_DIM = 16;
constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2
constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8
constexpr size_t ITERATIONS_BLOCK = 2; // iterations to calculate 2d block amaxes of 1 tile
constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_rows = rows - block_offset_Y;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const size_t tid_X_colwise = threadIdx.x;
const size_t tid_Y_t = tid_X_colwise;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
const size_t thread_offset_X_colwise = tid_X_colwise;
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
const size_t scales_offset_X_t = scales_block_offset_X_t;
const size_t SFs_per_row = cols / SCALE_DIM;
const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;
  // Helps resolve bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
constexpr size_t out_mem_rowwise_scales = 0;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding.
  // __align__(128) does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
  // Compute global encoding/decoding scaling factors for all S_dec_b
const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
? 1.0f
: compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
  // NOTE: This matches how the emulation code was written.
const float S_dec_rowwise = 1.0 / S_enc_rowwise;
const float S_enc_colwise = (amax_colwise_ptr == nullptr)
? S_enc_rowwise
: compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
const float S_dec_colwise = 1.0 / S_enc_colwise;
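  // NVFP4 two-level scaling: S_enc is a per-tensor FP32 encode scale derived from the
  // global amax, and S_dec_b (computed per scaling block below) is an E4M3 decode scale.
  // Each element x is converted as fp4(x * S_enc_b) with S_enc_b = 1 / (S_dec_b * S_dec),
  // so dequantization is fp4_val * S_dec_b * S_dec.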
const size_t warp_id = threadIdx.x / 32;
const size_t lane_id = threadIdx.x % 32;
float thread_amax = 0.0f;
const size_t block_in_warp = lane_id / BLOCKS_PER_WARP;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
__shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1];
// Helper function for warp reduction
auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float {
#pragma unroll
for (int delta = 8; delta >= 1; delta /= 2) {
float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta);
thread_amax = fmaxf(thread_amax, other_amax);
}
return thread_amax;
};
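  // Note: the XOR shuffle with deltas 8..1 reduces over 16 consecutive lanes, so after
  // warp_reduce_amax every lane in a 16-lane half-warp holds that half-warp's amax; with
  // BLOCK_DIM == 16, each half-warp maps to one 16-column block (hence the lane 0/16
  // writes into block_amax_matrix below).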
initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (size_t stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff_offset_in = buff * BUFF_IN_SIZE;
const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
#pragma unroll
for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
const size_t block_in_tile_y = block_iter;
const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
for (int elem = 0; elem < BLOCK_DIM; elem += 2) {
const size_t elem_0_row = block_iter * BLOCK_DIM + elem;
const size_t elem_1_row = elem_0_row + 1;
const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
const size_t elem_1_col = elem_0_col;
const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col;
const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col;
IType2 val_2x;
val_2x.x = in_sh[shmem_offset_0];
val_2x.y = in_sh[shmem_offset_1];
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x);
}
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else {
for (int elem = 0; elem < BLOCK_DIM; ++elem) {
const size_t elem_row = block_iter * BLOCK_DIM + elem;
const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
// Bounds checking
const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + elem_row >= rows);
const bool col_out_of_bounds = (block_offset_X + elem_col >= cols);
if (!row_out_of_bounds && !col_out_of_bounds) {
const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col;
float elt = static_cast<float>(in_sh[shmem_offset]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset] = static_cast<IType>(elt);
}
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
}
}
// Warp reduction to get block amax
block_amax = warp_reduce_amax(thread_amax, block_in_warp);
if (lane_id == 0 || lane_id == 16) {
block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax;
}
}
    // Sync threads to ensure block_amax_matrix stores are complete
__syncthreads();
// COLWISE scaling
if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
const size_t block_in_tile_y = it;
const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;
const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
const size_t in_thread_offset_X = thread_offset_X_colwise;
const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;
const size_t shmem_offset_base_colwise_in =
buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
const size_t shmem_offset_base_colwise_out_t =
buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;
block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];
float in_compute_colwise[SCALE_DIM];
IType in_colwise_IType[SCALE_DIM];
        // 1. Load data in
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
}
} else {
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_colwise);
        // Store scaling factors through SHMEM
const size_t scale_idx_sh =
tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
for (int e = 0; e < SCALE_DIM / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
regs[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(elts, block_scale_inverse_2x,
rbits);
} else {
const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
regs[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const int group = thread_lane / 16;
uint32_t val[2];
uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
        // Helps reduce bank conflicts
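        // Both groups write identical data to out_t_data_sh; they only issue the two
        // 32-bit stores in opposite order (word 0 then 1 vs. word 1 then 0), which
        // staggers shared-memory accesses across banks between the two half-warps.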
switch (group) {
case 0:
val[0] = regs_4x[0];
val[1] = regs_4x[1];
break;
case 1:
val[0] = regs_4x[1];
val[1] = regs_4x[0];
break;
}
uint32_t *out_t_data_sh_as_uint32_t =
reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2;
out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2;
}
}
// ROWWISE scaling
{
const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
const size_t block_in_tile_y = it;
const size_t block_in_tile_x = tid_X_rowwise;
const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const size_t shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const size_t shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];
float in_compute_rowwise[SCALE_DIM];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
}
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const size_t j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
// Check boundaries
const size_t scales_offset_Y =
scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
const size_t scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;
// const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
const bool rowwise_scale_is_within_bounds_Y =
(stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
scales_ptr[scale_idx_global] = S_dec_b_fp8;
}
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else if constexpr (IS_CACHED_ACT_OP) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const int j = w * PACK_SIZE + 4 * e;
const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
out.data.elt[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t global_offset_Y_t = block_offset_Y_t;
const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
if constexpr (RETURN_TRANSPOSE) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
} // end of stages
  // Vectorized store of scaling factors through SHMEM
if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
const size_t count = // number of scales in Y dimension of this chunk
(chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
// Fast path: vectorized store when destination is properly aligned
scales_vec.store_to(dst);
} else {
// Safe path: element-wise store for tails or unaligned destinations
scales_vec.store_to_elts(dst, 0, count);
}
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#else
  NVTE_DEVICE_ERROR("sm_100 or higher is required.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace nvfp4_transpose
#endif // CUDA_VERSION > 12080
// Compile-time flag to choose kernel variant
#ifndef USE_2D_NVFP4_KERNEL
#define USE_2D_NVFP4_KERNEL 0
#endif
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
bool use_2d_quantization>
void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
const QuantizationConfig *quant_config, cudaStream_t stream) {
#if CUDA_VERSION > 12080
bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
  // If the transposed output is allocated, return the transposed data. Otherwise, it is not
  // necessary to return the transposed data.
// TODO(Frank): Is there a better way to do this?
bool return_transpose = output->has_columnwise_data();
using namespace nvfp4_transpose;
using namespace ptx;
checkCuDriverContext(stream);
CheckNoopTensor(*noop, "cast_noop");
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output", false);
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated.");
NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
if (return_transpose) {
NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated.");
NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype),
"Transposed output must have FP4 type.");
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
"Transposed scaling tensor must be allocated");
}
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
NVTE_CHECK(rows % 32 == 0,
"Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
NVTE_CHECK(cols % 32 == 0,
"Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_NUM;
const size_t scale_stride = output->scale_inv.shape[1];
const size_t scale_stride_transpose = output->columnwise_scale_inv.shape[1];
nvfp4_scale_t *const scales_ptr = reinterpret_cast<nvfp4_scale_t *>(output->scale_inv.dptr);
nvfp4_scale_t *const scales_transpose_ptr =
reinterpret_cast<nvfp4_scale_t *>(output->columnwise_scale_inv.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const float *const amax_rowwise_ptr = reinterpret_cast<const float *>(output->amax.dptr);
const float *const amax_colwise_ptr =
reinterpret_cast<const float *>(output->columnwise_amax.dptr);
const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr;
const size_t *rng_state = nullptr;
if (rng_state_tensor != nullptr) {
Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
}
using IType = bf16;
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output{};
alignas(64) CUtensorMap tensor_map_output_transpose{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
sizeof(IType) * 8);
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
4);
if (return_transpose) {
create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows,
BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4);
}
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_data_mem = buff_size_aligned_out;
constexpr size_t out_data_transpose_mem = buff_size_aligned_out;
constexpr size_t out_scales_transpose_mem = buff_size_scales;
constexpr size_t out_mem = out_data_mem + out_data_transpose_mem;
constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT;
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,
TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, {
auto kernel = nvfp4_transpose_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
if constexpr (use_2d_quantization) {
kernel = nvfp4_transpose_kernel_2D<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
}
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr,
scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols,
scale_stride, scale_stride_transpose, rng_state);
}););
#else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // CUDA_VERSION > 12080
}
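// Illustrative call (a sketch, not from this file): a dispatch layer is expected to pick
// the 2D-scaling variant via the use_2d_quantization template parameter, e.g.
//
//   if (use_2d)
//     nvfp4_quantize_transpose<false, Empty, nullptr, /*use_2d_quantization=*/true>(
//         input, &noop, &output, quant_config, stream);
//   else
//     nvfp4_quantize_transpose<false, Empty, nullptr, /*use_2d_quantization=*/false>(
//         input, &noop, &output, quant_config, stream);
//
// The identity activation (ParamOP = Empty, OP = nullptr) and the "use_2d" flag above are
// assumptions for illustration; the actual selection happens in the quantize entry points.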
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
...@@ -14,6 +14,10 @@
#include <cuda.h>
#include <cuda_runtime.h>
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif // CUDA_VERSION >= 12080
namespace transformer_engine {
namespace ptx {
...@@ -117,9 +121,13 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
  return __int_as_float(biased_exp << FP32_MANTISSA_BITS);
}
#define CUDA_ARCH_HAS_FEATURE_SM10X_ALL \
((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM103_ALL)))
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
  uint16_t out;
  asm volatile(
      "{\n"
...@@ -222,18 +230,86 @@ struct alignas(2 * sizeof(T)) FPx2 {
  T y;
};
template <typename T>
struct FPx4 {
T x1;
T x2;
T x3;
T x4;
};
template <typename T>
struct Type2x {};
template <>
struct Type2x<float> {
using type = float2;
};
template <>
struct Type2x<bf16> {
using type = __nv_bfloat162;
};
template <>
struct Type2x<fp16> {
using type = __half2;
};
using floatx2 = FPx2<float>;
using bf16x2 = FPx2<bf16>;
using fp16x2 = FPx2<fp16>;
using fp8e4m3x2 = FPx2<fp8e4m3>;
using fp8e5m2x2 = FPx2<fp8e5m2>;
using floatx4 = FPx4<float>;
using bf16x4 = FPx4<bf16>;
using fp16x4 = FPx4<fp16>;
using fp8e4m3x4 = FPx4<fp8e4m3>;
using fp8e5m2x4 = FPx4<fp8e5m2>;
static_assert(sizeof(floatx2) == 8);
static_assert(sizeof(bf16x2) == 4);
static_assert(sizeof(fp16x2) == 4);
static_assert(sizeof(fp8e4m3x2) == 2);
static_assert(sizeof(fp8e5m2x2) == 2);
#if CUDA_VERSION >= 12080
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
static_assert(sizeof(fp4e2m1x2) == 1);
static_assert(sizeof(fp4e2m1x4) == 2);
#endif // CUDA_VERSION >= 12080
// cvt.rn.satfinite.e2m1x2.f32 d, a, b; // Convert two FP32 values to two packed e2m1
// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 introduced in PTX ISA version 8.6.
// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 is supported on the following architectures:
// sm_100a
// sm_101a
// sm_120a
// When converting to .e2m1x2 data formats, the destination operand d has .b8 type.
// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format,
// and the converted values are packed in the destination operand d such that the value
// converted from input a is stored in the upper 4 bits of d and the value converted
// from input b is stored in the lower 4 bits of d.
// SIMD like "Fused" cast + multiplication (x4)
#if CUDA_VERSION >= 12080
template <typename Tx2>
__device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23,
const float scale) {
const float x0 = static_cast<float>(in01.x) * scale;
const float x1 = static_cast<float>(in01.y) * scale;
const float x2 = static_cast<float>(in23.x) * scale;
const float x3 = static_cast<float>(in23.y) * scale;
out = fp4e2m1x4(make_float4(x0, x1, x2, x3));
}
#endif // CUDA_VERSION >= 12080
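// Illustrative use of mul_cvt_4x (an assumed caller, not from this diff): scale four bf16
// values by one FP32 factor and pack them into a single fp4e2m1x4 (16 bits), e.g.
//
//   bf16x2 in01{a, b}, in23{c, d};
//   fp4e2m1x4 packed;
//   mul_cvt_4x(packed, in01, in23, block_scale_inverse);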
// SIMD like "Fused" cast + multiplication (x2)
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
                                            const floatx2 &scale) {
...@@ -369,7 +445,7 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const
      "r"(reinterpret_cast<const uint32_t &>(p2)));
}
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}  // namespace ptx
...
...@@ -22,7 +22,8 @@
      .value("kFloat16", transformer_engine::DType::kFloat16) \
      .value("kBFloat16", transformer_engine::DType::kBFloat16) \
      .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
      .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \
  pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
      .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
      .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \
...
...@@ -35,6 +35,26 @@ constexpr uint32_t THREADS_PER_WARP = 32;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Device-side error
#define NVTE_DEVICE_ERROR(message) \
do { \
printf("%s:%d in function %s (thread (%d,%d,%d), block (%d,%d,%d)): %s\n", __FILE__, __LINE__, \
__func__, threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z, \
(message)); \
assert(0); \
} while (false)
// Device-side error on thread 0
#define NVTE_DEVICE_THREAD0_ERROR(message) \
do { \
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && \
threadIdx.y == 0 && threadIdx.z == 0) { \
NVTE_DEVICE_ERROR(message); \
} \
} while (false)
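// Example (as used by the NVFP4 kernels in this change): fail loudly when a kernel is
// compiled or launched for an unsupported architecture, e.g.
//   NVTE_DEVICE_ERROR("sm_100 or higher is required.");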
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 operator+(const float2 &a, const float2 &b) {  // NOLINT(*)
  return {a.x + b.x, a.y + b.y};
}
...
...@@ -89,3 +89,5 @@ GemmParallelModes = ("row", "column", None)
dist_group_type = torch.distributed.ProcessGroup
MXFP8_BLOCK_SCALING_SIZE = 32
NVFP4_BLOCK_SCALING_SIZE = 16
...@@ -13,6 +13,8 @@ from ..utils import get_sm_count, _empty_tensor
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..tensor.utils import is_experimental
from ..experimental.gemm import experimental_gemm
from ...debug.pytorch.debug_quantization import DebugQuantizer
__all__ = [
...@@ -77,6 +79,24 @@
    if not out.is_contiguous():
        raise ValueError("Output tensor is not contiguous.")
    # If A or B is an experimental tensor -> dispatch to the quantizer's qgemm implementation
if is_experimental(A) or is_experimental(B):
return experimental_gemm(
A,
B,
workspace,
out_dtype,
quantization_params,
gelu,
gelu_in,
accumulate,
layout,
out,
bias,
use_split_accumulator,
grad,
)
    debug_quantizer = None
    if isinstance(quantization_params, DebugQuantizer):
        debug_quantizer = quantization_params
...
...@@ -12,6 +12,20 @@
namespace transformer_engine::pytorch {
/*! convert fp4 data shape back to original shape */
std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose) {
std::vector<size_t> ret;
size_t start_idx = (transpose) ? 1 : 0;
for (size_t i = start_idx; i < shape.size() - 1; ++i) {
ret.push_back(shape[i]);
}
ret.push_back(shape.back() * 2);
if (transpose) {
ret.push_back(shape.front());
}
return ret;
}
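// For instance (illustrative shapes): FP4 packs two values per byte along the last
// dimension, so the rowwise data of a logical [1024, 512] tensor has byte shape
// [1024, 256] and converts back with transpose = false, while the columnwise (transposed)
// data has byte shape [512, 512] and converts back to the same logical [1024, 512] with
// transpose = true.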
std::vector<size_t> getTensorShape(const at::Tensor& t) {
  std::vector<size_t> shape;
  for (auto s : t.sizes()) {
...@@ -291,4 +305,20 @@ size_t roundup(const size_t value, const size_t multiple) {
  return ((value + multiple - 1) / multiple) * multiple;
}
void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) {
NVTE_SCOPED_GIL_RELEASE({
nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_,
at::cuda::getCurrentCUDAStream());
});
}
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread) {
at::PhiloxCudaState philox_args;
std::lock_guard<std::mutex> lock(gen->mutex_);
philox_args = gen->philox_cuda_state(elts_per_thread);
return philox_args;
}
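// Sketch of how these helpers are expected to fit together (assumed usage; the NVFP4
// kernels in this change read rng_state as {seed, offset}):
//
//   auto* gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
//       std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
//   at::PhiloxCudaState philox_args = init_philox_state(gen, /*elts_per_thread=*/4);
//   at::Tensor rng_state = at::empty({2}, at::dtype(at::kLong).device(at::kCUDA));
//   philox_unpack(philox_args, rng_state.data_ptr<int64_t>());
//   // rng_state can then be attached to the quantization config for stochastic rounding.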
}  // namespace transformer_engine::pytorch
...@@ -31,6 +31,7 @@
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/fused_router.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/hadamard_transform.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/normalization.h>
...@@ -194,20 +195,25 @@ class Float8CurrentScalingQuantizer : public Quantizer {
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;
  /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer.
   *
   * The amax is zeroed out. Most TE kernels that output amax expect
   * amax to be initialized to zero.
   */
  std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
      const std::vector<size_t>& shape, DType dtype);
  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
  /*! @brief Quantize to FP8, skipping local amax computation
   *
   * The quantizer's amax pointer is assumed to already hold the local
   * amax. The amax may still be reduced across the amax reduction
   * group.
   */
  void quantize_with_amax(TensorWrapper& input, TensorWrapper& out,
                          const std::optional<TensorWrapper>& noop_flag = std::nullopt);
...@@ -277,6 +283,60 @@ class MXFP8Quantizer : public Quantizer {
  std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};
class NVFP4Quantizer : public Quantizer {
public:
// fp4 dtype
DType dtype;
// amax reduction for low precision FP4 AG
bool with_amax_reduction;
c10::intrusive_ptr<dist_group_type> amax_reduction_group;
// random hadamard transform
bool with_rht;
bool with_post_rht_amax;
// 2D block scaling
bool with_2d_quantization;
bool stochastic_rounding;
int rht_matrix_random_sign_mask_t;
at::Tensor rht_matrix;
explicit NVFP4Quantizer(const py::handle& quantizer);
NVTEScalingMode get_scaling_mode() const override { return NVTE_NVFP4_1D_SCALING; }
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
/*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer
*
* The amax is zeroed out. Most TE kernels that output amax expect
* amax to be initialized to zero.
*/
std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
TensorWrapper& quantized_tensor, DType dtype);
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
/*! @brief Quantize to NVFP4, skipping local amax computation
*
* The input tensor's amax pointer is assumed to already hold the
* local amax. The amax may still be reduced across the amax
* reduction group.
*/
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out);
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
private:
void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
};
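// Typical flow through this interface (a sketch based on the declarations above, not the
// actual binding code):
//
//   std::unique_ptr<Quantizer> q = convert_quantizer(py_quantizer);
//   auto [out, out_py] = q->create_tensor(shape, DType::kFloat4E2M1);
//   q->quantize(input, out);  // computes the amax, then casts to NVFP4
//   // or, if the amax was already produced by an earlier kernel:
//   // static_cast<NVFP4Quantizer*>(q.get())->quantize_with_amax(input, out);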
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
std::vector<size_t> getTensorShape(const at::Tensor& t);
...@@ -420,6 +480,15 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
size_t roundup(const size_t value, const size_t multiple);
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose);
// unpack the PhiloxCudaState into CUDA tensor
void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr);
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread);
}  // namespace transformer_engine::pytorch
namespace std {
...