Unverified Commit 3f5b4754 authored by Kirthi Shankar Sivamani, committed by GitHub

[Core][PyTorch] NVFP4 recipe (#2177)



* Add NVFP4 recipe
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Frank Sun <frsun@nvidia.com>
Co-authored-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add MathDx dependency to GitHub builds
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Suggestions from GitHub Copilot
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move 2x shape logic from core to PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix compilation errors with CUDA 12.1
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* SM 70 is not supported in CUDA 13
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Typo
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Revert "Move 2x shape logic from core to PyTorch"

This reverts commit f8b2a2d0111d9af690b43bb98ae448d9a430a185.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Added dequantize kernel for FP4
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add NVFP4 support with fusible ops

Use logical tensor dims for PyTorch NVFP4 tensors. Temporarily add unfused dequantize impl. Fix bug where NVFP4 recipe was not configurable.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix logic for 2x shapes and move to PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CG test model config
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Debug NVFP4 tensor size function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Proper handling of the RNG state
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Test SR properly
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix workspace size for GEMM heuristic.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix compile error in C++ NVFP4 test

Fix some numeric errors when blocks are all zero.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix distributed test problem shape
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* proper assert dim for low precision AG TP
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* clean up duplicated code in nvfp4_utils.cuh
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* pylint: disable=unused-argument
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* `nvte_cublas_gemm_v2` to take alpha pointer (#12)

* make nvte_cublas_gemm_v2 take alpha/beta pointers
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* users are expected to pass a valid C_tensor
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* typos
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* API to have const float* alpha
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Minor tweaks

Support arbitrary beta scales. Increase workspace to be aligned to 128 bytes.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug IMA with alpha pointer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Support fused amax kernels with NVFP4 quantization
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable fused amax with cuDNN LayerNorm kernel
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add NVFP4 cases to distributed tests for TE ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Change assert to NVTE_CHECK in the hadamard cast fusion
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix compile error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use global thread IDs for Philox subsequences
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add shape checks for NVFP4 cast kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Do not fuse amax if cuDNN normalization is forced by envvar
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Frank Sun <frsun@nvidia.com>
Co-authored-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent dfeef1a2
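
For context, a minimal usage sketch of the new recipe (not part of this diff; it assumes default NVFP4BlockScaling settings, a GPU with FP4 support, and tensor dims divisible by 16):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import NVFP4BlockScaling

    recipe = NVFP4BlockScaling()  # defaults assumed; see the recipe module for options
    layer = te.Linear(1024, 1024, bias=True).cuda()
    inp = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = layer(inp)  # forward GEMMs run with NVFP4 quantization
    out.sum().backward()  # backward GEMMs use the NVFP4 grad-output quantizers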
......@@ -16,6 +16,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.tensor.utils import is_experimental
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
......@@ -29,6 +30,7 @@ from .base import (
from ..fp8 import FP8GlobalStateManager
from ..utils import (
assert_dim_for_fp8_exec,
assert_dim_for_all_gather,
cast_if_needed,
clear_tensor_data,
divide,
......@@ -53,7 +55,7 @@ from ..distributed import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, WeightGradStore
from ._common import apply_normalization, noop_cat, WeightGradStore, get_module_quantizers
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
......@@ -135,6 +137,8 @@ class _LayerNormLinear(torch.autograd.Function):
if ub_name is not None:
nvtx_label = f"{nvtx_label}.{ub_name}"
with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Make sure input dimensions are compatible
out_features, in_features = weight.shape
inp_shape = inp.shape
......@@ -144,6 +148,7 @@ class _LayerNormLinear(torch.autograd.Function):
inputmat = inp
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer)
# Cast for native AMP
nvtx_range_push(f"{nvtx_label}.norm_input_cast")
......@@ -157,7 +162,6 @@ class _LayerNormLinear(torch.autograd.Function):
weight_requires_grad = weight.requires_grad
backward_needs_input = is_grad_enabled and weight_requires_grad
with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Configure Userbuffers communication (comm+GEMM overlap)
if debug: # turn off userbuffers in debug mode
......@@ -190,11 +194,13 @@ class _LayerNormLinear(torch.autograd.Function):
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
experimental = is_experimental(input_quantizer)
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not experimental
)
# Apply normalization
......@@ -240,7 +246,8 @@ class _LayerNormLinear(torch.autograd.Function):
quantizer = None
if fp8 or debug:
quantizer = input_quantizer
if not with_quantized_norm:
# experimental recipe doesn't need to support quantized AG
if not with_quantized_norm and not experimental:
ln_out = quantizer(ln_out)
quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
......@@ -1422,6 +1429,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)
# elif other recipes (mxfp8, etc)
def reset_layer_norm_parameters(self) -> None:
......@@ -1526,11 +1535,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
......@@ -1763,6 +1768,28 @@ class LayerNormLinear(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_linear."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# set input_quantizer with amax reduction TP group
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
......
......@@ -17,6 +17,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.tensor.utils import is_experimental
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
......@@ -40,6 +41,7 @@ from ..utils import (
init_method_constant,
cast_if_needed,
assert_dim_for_fp8_exec,
assert_dim_for_all_gather,
clear_tensor_data,
requires_grad,
needs_quantized_gemm,
......@@ -64,6 +66,7 @@ from ..tensor.float8_tensor import (
Float8Tensor,
)
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
......@@ -114,7 +117,8 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
}
# no activation fusion written yet
# Per-tensor current scaling or fp8 blockwise scaling: []
if recipe.float8_current_scaling() or recipe.float8_block_scaling():
# TODO(ksivaman): Fuse nvfp4 act once kernel is available.
if recipe.float8_current_scaling() or recipe.float8_block_scaling() or recipe.nvfp4():
return {
"gelu": (tex.gelu, tex.dgelu, None),
"geglu": (tex.geglu, tex.dgeglu, None),
......@@ -211,6 +215,7 @@ class _LayerNormMLP(torch.autograd.Function):
inputmat = inp.view((-1, in_features))
if fp8:
assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer)
activation_func = _act_func(
activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
......@@ -258,11 +263,13 @@ class _LayerNormMLP(torch.autograd.Function):
# high precision layernorm output and output of the linear are returned
# for debug: : layernorm output = High precision to enable processing of this norm
experimental = is_experimental(fc1_input_quantizer)
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not experimental
)
# Apply normalization
......@@ -302,7 +309,8 @@ class _LayerNormMLP(torch.autograd.Function):
quantizer = None
if fp8 or debug:
quantizer = fc1_input_quantizer
if not with_quantized_norm:
# experimental recipe doesn't need to support quantized AG
if not with_quantized_norm and not experimental:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag:
......@@ -548,6 +556,7 @@ class _LayerNormMLP(torch.autograd.Function):
if not fc2_weight.requires_grad:
clear_tensor_data(act_out)
act_out = None
tensors_to_save, tensor_objects = prepare_for_saving(
inputmat,
ln_weight,
......@@ -673,6 +682,7 @@ class _LayerNormMLP(torch.autograd.Function):
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
......@@ -1014,7 +1024,10 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.fp8:
# TODO float8 blockwise current scaling has no bgrad fusion for now
if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer):
# TODO(ksivaman): Re-add fusion once kernel is available.
if isinstance(
ctx.fc1_grad_output_quantizer, (Float8BlockQuantizer, NVFP4Quantizer)
):
fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
dact = ctx.fc1_grad_output_quantizer(dact)
else:
......@@ -1690,6 +1703,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
def reset_layer_norm_parameters(self) -> None:
......@@ -1908,7 +1923,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
fc2_input_quantizer.set_usage(
rowwise=True,
columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
columnwise=isinstance(
fc2_input_quantizer,
(MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer),
),
)
fc1_input_quantizer.internal = True
if fp8_output:
......@@ -2113,6 +2131,28 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group
def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + layernorm_mlp."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.set_parallel_mode:
# fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.set_parallel_mode:
# fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
return [self.fc1_weight, self.fc2_weight]
......
......@@ -25,7 +25,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ._common import noop_cat, WeightGradStore
from ._common import noop_cat, WeightGradStore, get_module_quantizers
from ..fp8 import FP8GlobalStateManager
from ..utils import (
cast_if_needed,
......@@ -35,6 +35,7 @@ from ..utils import (
requires_grad,
needs_quantized_gemm,
assert_dim_for_fp8_exec,
assert_dim_for_all_gather,
nvtx_range_pop,
nvtx_range_push,
)
......@@ -65,6 +66,7 @@ from ..tensor.quantized_tensor import (
)
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.utils import is_experimental
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState
......@@ -151,6 +153,9 @@ class _Linear(torch.autograd.Function):
ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.AG
# experimental recipe check
experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer)
# ------------------------------------------------------
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
......@@ -161,6 +166,7 @@ class _Linear(torch.autograd.Function):
own_quantized_input = False
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer)
if save_original_input:
assert not isinstance(
input_quantizer, Float8Quantizer
......@@ -172,7 +178,7 @@ class _Linear(torch.autograd.Function):
if fp8 or debug:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if not isinstance(inputmat, QuantizedTensorBase):
if not isinstance(inputmat, QuantizedTensorBase) and not experimental:
own_quantized_input = True
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if isinstance(
......@@ -442,6 +448,7 @@ class _Linear(torch.autograd.Function):
ctx.main_grad_func = lambda: weight.main_grad
ctx.debug = debug
ctx.experimental = experimental
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = bias is not None
......@@ -609,7 +616,7 @@ class _Linear(torch.autograd.Function):
if isinstance(inputmat, QuantizedTensorBase):
# Input tensor is already quantized
pass
elif ctx.debug:
elif ctx.debug or ctx.experimental:
# Debug quantizer will be applied immediately before wgrad GEMM
pass
else:
......@@ -698,6 +705,7 @@ class _Linear(torch.autograd.Function):
# dgrad GEMM
# Note: dx = dy * w
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
gemm_out, *_, reduce_scatter_out = general_gemm(
weight_fp8,
......@@ -1326,6 +1334,8 @@ class Linear(TransformerEngineBaseModule):
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
elif recipe.nvfp4():
self._customize_quantizers_nvfp4(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
def reset_parameters(self, defer_init=False):
......@@ -1410,12 +1420,7 @@ class Linear(TransformerEngineBaseModule):
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
)
quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
......@@ -1655,6 +1660,28 @@ class Linear(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
assert recipe.nvfp4(), "Incorrect recipe."
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# customize input_quantizer with amax reduction TP group
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].with_amax_reduction = True
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
else:
if self.sequence_parallel and self.parallel_mode == "row":
# customize grad_output_quantizer with amax reduction TP group
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8 and not self.fp8_calibration:
......
......@@ -926,6 +926,7 @@ class BasicLinear(BasicOperation):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)
# Recipe-specific configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
......@@ -940,6 +941,13 @@ class BasicLinear(BasicOperation):
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
if recipe.nvfp4():
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Get autocast dtype if needed
if torch.is_autocast_enabled():
......
......@@ -54,6 +54,7 @@ def get_all_tensor_types():
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
)
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor, NVFP4TensorBase
all_tensor_types = [
torch.Tensor,
......@@ -64,5 +65,7 @@ def get_all_tensor_types():
MXFP8TensorBase,
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
NVFP4Tensor,
NVFP4TensorBase,
]
return all_tensor_types
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Mixin class holding data specific for NVFP4Tensor"""
from __future__ import annotations
from collections.abc import Iterable
import functools
import math
from typing import Any, Dict, Optional, Tuple, Union
import warnings
import torch
# import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorBase
# from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor
@functools.lru_cache(maxsize=None)
def _fp4_e2m1_vals(device: torch.device, dtype: torch.dtype) -> torch.Tensor:
"""Values representable in FP4 E2M1 format"""
return torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
device=device,
dtype=dtype,
)
class _FromNVFP4Func(torch.autograd.Function):
"""Cast from NVFP4 to other dtype"""
@staticmethod
def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: NVFP4TensorBase,
dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Dequantize row-wise data
if tensor._rowwise_data is not None:
### TODO(tmoon): Debug dequantize kernel and remove unfused impl
# return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
# Tensor properties
shape = list(tensor._rowwise_data.size())
shape[-1] *= 2
device = tensor._rowwise_data.device
# Convert FP4E2M1 values to FP32
data = tensor._rowwise_data.view(torch.uint8).to(torch.int32)
data = torch.stack((data & 0x0F, data >> 4), dim=-1).reshape(shape)
data = _fp4_e2m1_vals(device, dtype=torch.float32)[data]
data = data.to(torch.float32).contiguous()
# Convert FP8E4M3 block scales to FP32
block_scales = tensor._rowwise_scale_inv
block_scales = block_scales.reshape(-1, block_scales.size(-1))
block_scales = block_scales[: math.prod(shape[:-1]), : shape[-1] // 16]
block_scales = block_scales.view(torch.float8_e4m3fn).to(torch.float32)
# Convert amax to FP32 tensor scale
tensor_scale = tensor._amax_rowwise / (6.0 * 448.0) # Scale by FP4E2M1 and FP8E4M3 max
# Apply scales
block_data = data.view(-1, 16)
block_data *= tensor_scale.view(()) * block_scales.reshape(-1, 1)
return data.to(dtype)
if tensor._columnwise_data is not None:
raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!")
raise ValueError("Attempted to dequantize NVFP4 tensor with no data")
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision
return grad, None
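
As an aside (illustrative only, not part of the file above): each uint8 in _rowwise_data packs two FP4 E2M1 values, low nibble first, and the per-tensor scale is amax / (6.0 * 448.0), 6.0 being the E2M1 maximum and 448.0 the E4M3 maximum. A standalone decode of one packed byte:

    import torch

    fp4_vals = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
    )
    packed = torch.tensor([0x73], dtype=torch.uint8).to(torch.int32)
    low, high = packed & 0x0F, packed >> 4
    print(fp4_vals[low], fp4_vals[high])  # 1.5 and 6.0, before block and tensor scales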
class NVFP4TensorBase(QuantizedTensorBase):
"""Mixin class that holds data attributes of NVFP4Tensor.
NVFP4Tensor inherits from the PyTorch tensor class and this mixin
class. If this class is instantiated directly, it has the same
data, lower CPU overhead, and less functionality. It should only
be instantiated directly for performance-critical internal usage.
"""
_rowwise_data: Optional[torch.Tensor]
_columnwise_data: Optional[torch.Tensor]
_quantizer: Optional[Quantizer]
_rowwise_scale_inv: torch.Tensor
_columnwise_scale_inv: torch.Tensor
_fp4_dtype: TE_DType
_amax_rowwise: torch.Tensor
_amax_columnwise: torch.Tensor
def __new__(
cls,
rowwise_data: Optional[torch.Tensor],
rowwise_scale_inv: torch.Tensor,
columnwise_data: Optional[torch.Tensor],
columnwise_scale_inv: torch.Tensor,
amax_rowwise: torch.Tensor,
amax_columnwise: torch.Tensor,
fp4_dtype: TE_DType,
quantizer: Optional[Quantizer],
*args,
**kwargs,
):
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._fp4_dtype = fp4_dtype
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
instance._amax_rowwise = amax_rowwise
instance._amax_columnwise = amax_columnwise
return instance
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
for t in (
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
self._amax_rowwise,
self._amax_columnwise,
):
if t is not None:
t.data = _empty_tensor()
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
"rowwise_data": self._rowwise_data,
"rowwise_scale_inv": self._rowwise_scale_inv,
"columnwise_data": self._columnwise_data,
"columnwise_scale_inv": self._columnwise_scale_inv,
"amax_rowwise": self._amax_rowwise,
"amax_columnwise": self._amax_columnwise,
"fp4_dtype": self._fp4_dtype,
"quantizer": self._quantizer,
}
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorBase]:
"""Prepare the tensor base for saving for backward"""
tensors = [
self._rowwise_data,
self._columnwise_data,
self._rowwise_scale_inv,
self._columnwise_scale_inv,
self._amax_rowwise,
self._amax_columnwise,
]
self._rowwise_data = None
self._columnwise_data = None
self._rowwise_scale_inv = None
self._columnwise_scale_inv = None
self._amax_rowwise = None
self._amax_columnwise = None
return tensors, self
def restore_from_saved(
self, tensors: list[Optional[torch.Tensor]]
) -> list[Optional[torch.Tensor]]:
"""Restore the tensor base data from the saved tensors list."""
self._rowwise_data = tensors[0]
self._columnwise_data = tensors[1]
self._rowwise_scale_inv = tensors[2]
self._columnwise_scale_inv = tensors[3]
self._amax_rowwise = tensors[4]
self._amax_columnwise = tensors[5]
return tensors[6:]
def get_data_tensors(self):
"""Get this Tensor's data."""
return self._rowwise_data, self._columnwise_data
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""Dequantize to a higher precision."""
return _FromNVFP4Func.forward(None, self, dtype)
def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]:
# pylint: disable=missing-function-docstring
# Infer tensor shape
shape = None
if self._rowwise_data is not None:
byte_shape = list(self._rowwise_data.size())
shape = byte_shape[:-1] + [byte_shape[-1] * 2]
elif self._columnwise_data is not None:
warnings.warn("Attempting to get shape of NVFP4 tensor with only column-wise data.")
byte_shape = list(self._columnwise_data.size())
shape = byte_shape[1:-1] + [byte_shape[-1] * 2, byte_shape[0]]
if shape is None:
raise RuntimeError("Attempted to get shape of NVFP4 tensor with no data")
# Return shape or dim
if dim is None:
return torch.Size(shape)
return shape[dim]
def view(self, shape: torch.Size):
# pylint: disable=missing-function-docstring
# Return input tensor if view not needed
cur_shape = self.size()
if shape is None or shape == cur_shape:
return self
# Canonicalize shape
if not isinstance(shape, Iterable):
shape = [shape]
elif len(shape) == 1 and isinstance(shape[0], Iterable):
shape = shape[0]
if -1 in shape:
shape = list(shape)
d_inferred = -math.prod(cur_shape) // math.prod(shape)
for i, d in enumerate(shape):
if d == -1:
shape[i] = d_inferred
break
if shape[-1] != cur_shape[-1]:
raise RuntimeError(
"NVFP4Tensor does not support reshaping inner dimension "
f"(attempted to reshape dims={tuple(cur_shape)} to {tuple(shape)})"
)
# Reshape data
new_rowwise_data = None
new_columnwise_data = None
if self._rowwise_data is not None:
if shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent row-wise data for NVFP4 tensor "
f"with shape={shape} as byte array."
)
byte_shape = list(shape[:-1]) + [shape[-1] // 2]
new_rowwise_data = self._rowwise_data.view(byte_shape)
if self._columnwise_data is not None:
columnwise_shape = (shape[-1], math.prod(shape[:-1]))
if columnwise_shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent column-wise data for NVFP4 tensor "
f"with shape={shape} as byte array."
)
byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
new_columnwise_data = self._columnwise_data.view(byte_shape)
# Construct tensor
return NVFP4TensorBase(
rowwise_data=new_rowwise_data,
rowwise_scale_inv=self._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=self._columnwise_scale_inv,
amax_rowwise=self._amax_rowwise,
amax_columnwise=self._amax_columnwise,
quantizer=self._quantizer,
fp4_dtype=self._fp4_dtype,
)
def __repr__(self):
data_rowwise = self.dequantize()
return (
"NVFP4TensorBase("
f"rowwise_scaled_data={data_rowwise},"
f"rowwise_scale_inv={self._rowwise_scale_inv},"
f"amax_rowwise={self._amax_rowwise},"
f"amax_columnwise={self._amax_columnwise},"
")"
)
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
columnwise_usage: Optional[bool] = None,
):
"""
For the NVFP4 format, columnwise scaled output is only produced by x2
scaling kernels, so this function only disables usages.
"""
# Default usage is based on available data
if rowwise_usage is None:
rowwise_usage = self._rowwise_data is not None
if columnwise_usage is None:
columnwise_usage = self._columnwise_data is not None
# Update row-scaled data
if rowwise_usage:
if self._rowwise_data is None:
raise RuntimeError(
"Requested row-wise usage, but NVFP4Tensor is missing row-scaled NVFP4 data"
)
if self._rowwise_scale_inv is None:
raise RuntimeError(
"Requested row-wise usage, but NVFP4Tensor is missing row-scaled scale-inverses"
)
if self._amax_rowwise is None:
raise RuntimeError(
"Requested row-wise usage, but NVFP4Tensor is missing per tensor"
" row-scaled scale-inverse"
)
else:
self._rowwise_data = None
self._rowwise_scale_inv = None
self._amax_rowwise = None
# Update column-scaled data
if columnwise_usage:
if self._columnwise_data is None:
raise RuntimeError(
"Requested column-wise usage, but NVFP4Tensor is missing column-scaled FP8 data"
)
if self._columnwise_scale_inv is None:
raise RuntimeError(
"Requested column-wise usage, "
"but NVFP4Tensor is missing column-scaled scale-inverses"
)
if self._amax_columnwise is None:
raise RuntimeError(
"Requested column-wise usage, "
"but NVFP4Tensor is missing per tensor column-scaled scale-inverse"
)
else:
self._columnwise_data = None
self._columnwise_scale_inv = None
self._amax_columnwise = None
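
For context (a sketch, not part of the file above), the PyTorch modules in this PR drive these usage flags through the quantizer and tensor APIs roughly as follows; the exact call sites are in the module diffs above:

    import torch
    from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer

    quantizer = NVFP4Quantizer(rowwise=True, columnwise=True)
    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    x_fp4 = quantizer(x)                        # holds row-wise and column-wise data
    x_fp4.update_usage(columnwise_usage=False)  # free the column-wise copy when no longer needed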
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""Tensor class with FP8 data"""
"""Tensor class with MXFP8 data"""
from __future__ import annotations
from collections.abc import Iterable
import math
......@@ -186,8 +186,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
Reciprocal of the scaling factor applied when
casting to FP8, i.e. the scaling factor that must
be applied when casting from FP8 to higher
precision. Can be inferred from fp8_meta if
provided.
precision.
dtype: torch.dtype, default = torch.float32
Nominal tensor datatype.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tensor class with NVFP4 data"""
from __future__ import annotations
from collections.abc import Iterable
import math
from typing import Optional, Tuple, Union
import functools
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling, Recipe
from ..constants import NVFP4_BLOCK_SCALING_SIZE, dist_group_type
from ..utils import (
canonicalize_process_group,
devices_match,
round_up_to_nearest_multiple,
)
from ._internal.nvfp4_tensor_base import NVFP4TensorBase, _FromNVFP4Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
aten = torch.ops.aten
def get_no_random_sign_vector() -> torch.Tensor:
"""Non-random sign vector for Hadamard transform."""
return torch.tensor([1], dtype=torch.float32)
def get_sign_from_vector(vector: torch.Tensor) -> int:
"""Convert sign vector to bitmask.
Used for random Hadamard transform.
"""
mask = 0
for i, v in enumerate(vector):
mask |= (v == -1) << i
return mask
def get_wgrad_sign_vector() -> torch.Tensor:
"""Hard-coded random signs for Hadamard transform.
https://xkcd.com/221/
"""
return torch.tensor(
[1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1],
dtype=torch.float32,
)
def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
"""Construct a 16x16 Hadamard matrix."""
assert hadamard_dimension == 16, "Only hadamard dimension 16 is supported."
hadamard_scale = 1 / math.sqrt(hadamard_dimension)
return (
torch.tensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1],
[1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1],
[1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1],
[1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1],
[1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1],
[1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1],
[1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1],
[1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1],
[1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1],
[1, 1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1],
[1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1],
[1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1],
[1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1],
[1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1, 1, 1, -1, -1],
[1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1],
],
dtype=torch.float32,
)
* hadamard_scale
)
@functools.lru_cache(maxsize=None)
def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor:
"""Construct matrix used in random Hadamard transform."""
hadamard_dimension = 16
if with_random_sign_mask:
signs = get_wgrad_sign_vector()
else:
signs = get_no_random_sign_vector()
sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32)
rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension)
return rht_matrix.to(dtype=torch.bfloat16).cuda()
@functools.lru_cache(maxsize=None)
def get_random_sign_mask_for_rht(with_random_sign_mask: bool) -> int:
"""Sign mask for random Hadamard transform."""
if with_random_sign_mask:
return get_sign_from_vector(get_wgrad_sign_vector())
return 0
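
A quick sanity check of the construction above (illustrative only, using the helpers defined in this file): the scaled 16x16 Hadamard matrix is orthonormal, and flipping row signs preserves that, so the RHT matrix is an isometry over each 16-element block:

    import torch

    H = get_hadamard_matrix(16)                  # already scaled by 1/sqrt(16)
    assert torch.allclose(H @ H.T, torch.eye(16), atol=1e-6)
    S = get_wgrad_sign_vector() * torch.eye(16)  # diagonal matrix of random signs
    R = S @ H                                    # RHT matrix before the bf16/CUDA cast
    assert torch.allclose(R @ R.T, torch.eye(16), atol=1e-6)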
class NVFP4Quantizer(Quantizer):
"""Builder class for NVFP4 tensors with NV block scaling"""
dtype: TE_DType
"""Random Hadamard Transform"""
with_rht: bool
with_post_rht_amax: bool
"""amax reduction options"""
with_amax_reduction: bool
amax_reduction_group: Optional[dist_group_type]
"""2D block scaling, only applicable for weights."""
with_2d_quantization: bool
"""Stochastic rounding, only applicable for gradients."""
stochastic_rounding: bool
"""RHT matrix random sign mask"""
rht_matrix_random_sign_mask_t: int
rht_matrix: torch.Tensor
def __init__(
self,
fp4_dtype: TE_DType = tex.DType.kFloat4E2M1,
rowwise: bool = True,
columnwise: bool = True,
with_amax_reduction: bool = False,
amax_reduction_group: Optional[dist_group_type] = None,
with_rht: bool = False,
with_post_rht_amax: bool = False,
with_2d_quantization: bool = False,
stochastic_rounding: bool = False,
with_random_sign_mask: bool = True,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.dtype = fp4_dtype
self.with_rht = with_rht
self.with_post_rht_amax = with_post_rht_amax
self.with_amax_reduction = with_amax_reduction
self.amax_reduction_group = amax_reduction_group
self.with_2d_quantization = with_2d_quantization
self.stochastic_rounding = stochastic_rounding
self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(with_random_sign_mask)
self.rht_matrix = get_rht_matrix(with_random_sign_mask)
def update_quantized(
self,
src: torch.Tensor,
dst: QuantizedTensor,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
assert isinstance(dst, NVFP4Tensor), f"Cannot store quantized NVFP4 in {type(dst)} type."
# Make sure input is in expected format
if not devices_match(src.device, dst.device):
src = src.to(device=dst.device)
if not src.is_contiguous():
src = src.contiguous()
# Launch cast kernel
tex.quantize(src, self, dst, noop_flag)
return dst
def is_quantizable(self, inp: torch.Tensor) -> bool:
"""Returns whether or not given inp can be quantized"""
if inp.ndim < 2:
return False
if inp.shape[-1] % NVFP4_BLOCK_SCALING_SIZE != 0:
return False
if math.prod(inp.shape[:-1]) % NVFP4_BLOCK_SCALING_SIZE != 0:
return False
return True
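
For illustration (a sketch, assuming a CUDA device with NVFP4 kernel support), the checks above mean both the last dim and the product of the leading dims must be multiples of 16:

    import torch
    from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer

    q = NVFP4Quantizer(rowwise=True, columnwise=False)
    assert q.is_quantizable(torch.empty(128, 256))      # both dims are multiples of 16
    assert not q.is_quantizable(torch.empty(128, 250))  # last dim is not a multiple of 16
    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    x_round_trip = q(x).dequantize(dtype=torch.bfloat16)  # quantize/dequantize round trip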
def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
"""Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization.
This method determines the shape of the scaling tensor needed for blockwise quantization,
taking into account the input tensor shape and whether columnwise scaling is used.
Parameters
----------
shape : Iterable[int]
Shape of the input tensor to be quantized
columnwise : bool
Whether to use columnwise scaling (True) or rowwise scaling (False)
Returns
-------
Tuple[int, int]
Shape of the scaling tensor as (outer_dim, inner_dim)
For NVFP4 1D blockwise quantization, blocksize is 16
- If columnwise: (round_to_multiple(K, 128), round_to_multiple(roundup(M / 16), 4))
- If rowwise: (round_to_multiple(M, 128), round_to_multiple(roundup(K / 16), 4))
A swizzle kernel is applied before the GEMM to match the block scaling factor layout that cuBLAS expects.
cuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
"""
M, K = 1, 1
M = math.prod(shape[:-1])
K = shape[-1]
if columnwise:
outer = round_up_to_nearest_multiple(K, 128)
inner = round_up_to_nearest_multiple(math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4)
return (outer, inner)
# rowwise
outer = round_up_to_nearest_multiple(M, 128)
inner = round_up_to_nearest_multiple(math.ceil(K / NVFP4_BLOCK_SCALING_SIZE), 4)
return (outer, inner)
@staticmethod
def get_columnwise_shape(shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of a tensor after columnwise quantization.
For NVFP4 columnwise quantization, it's performing 16x1 quantization block scaling.
Parameters
----------
shape : Iterable[int]
Original shape of the tensor
Returns
-------
Tuple[int, ...]
New shape with dimensions rearranged for columnwise layout.
For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1).
Returns empty tuple for empty input shape.
"""
if len(shape) == 0:
return tuple()
# After all-gather (AG), a reorganize kernel is called to restore the original shape
colwise_shape = [shape[-1]]
for i in range(len(shape) - 1):
colwise_shape.append(shape[i])
return tuple(colwise_shape)
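
To make the shape rules above concrete, a small worked example (illustrative only) for a 2048 x 4096 input with the 16-element block size:

    import math

    def round_up(x: int, m: int) -> int:
        return ((x + m - 1) // m) * m

    M, K, BLOCK = 2048, 4096, 16
    rowwise_scale_shape = (round_up(M, 128), round_up(math.ceil(K / BLOCK), 4))  # (2048, 256)
    colwise_scale_shape = (round_up(K, 128), round_up(math.ceil(M / BLOCK), 4))  # (4096, 128)
    colwise_data_shape = (K, M)  # last dim moved to the front for the column-wise layout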
@staticmethod
def convert_shape_for_fp4(shape: Iterable[int]) -> Tuple[int, ...]:
"""Convert shape for FP4 data by dividing the last dimension by 2"""
shape = list(shape)
shape[-1] = shape[-1] // 2
return tuple(shape)
def make_empty(
self,
shape: Iterable[int],
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
) -> NVFP4Tensor:
# Canonicalize tensor attributes
if device is None:
device = torch.device("cuda")
assert shape[-1] % NVFP4_BLOCK_SCALING_SIZE == 0, (
f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by"
f" {NVFP4_BLOCK_SCALING_SIZE}"
)
flat_first_dim = math.prod(shape[:-1])
assert flat_first_dim % NVFP4_BLOCK_SCALING_SIZE == 0, (
f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by"
f" {NVFP4_BLOCK_SCALING_SIZE}"
)
# Allocate FP4 data
data = None
scale_inv = None
amax_rowwise = None
if self.rowwise_usage:
data = torch.empty(self.convert_shape_for_fp4(shape), dtype=torch.uint8, device=device)
scale_shape = self.get_scale_shape(shape, columnwise=False)
scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device)
# Allocate per-tensor amax in FP32 format.
amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device)
# Allocate FP4 data transpose if needed
columnwise_data = None
columnwise_scale_inv = None
amax_columnwise = None
if self.columnwise_usage:
# Enforce a 2D shape to avoid cases like [S, B, H] where B can be 1,
# since the transposed shape [H, S, B] would make dividing the last dim by 2 yield zero
shape_2d = tuple([flat_first_dim, shape[-1]])
columnwise_data = torch.empty(
self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)),
dtype=torch.uint8,
device=device,
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape, dtype=torch.uint8, device=device
)
amax_columnwise = torch.zeros(1, dtype=torch.float32, device=device)
# Construct NVFP4 tensor
return NVFP4Tensor(
shape=shape,
dtype=dtype,
rowwise_data=data,
rowwise_scale_inv=scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
amax_rowwise=amax_rowwise,
amax_columnwise=amax_columnwise,
fp4_dtype=self.dtype,
quantizer=self,
requires_grad=requires_grad,
)
def calibrate(self, tensor: torch.Tensor) -> None:
pass # Calibration is no-op
def _canonicalized_amax_reduction_group(self) -> dist_group_type:
"""Get process group for amax reduction"""
return canonicalize_process_group(self.amax_reduction_group)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return NVFP4BlockScaling
class NVFP4Tensor(NVFP4TensorBase, QuantizedTensor):
"""Quantized tensor class with FP4 data
The tensor presents as having a standard, higher-precision dtype,
but the data itself is (scaled) FP4. For most tensor operations,
the data will be cast to the nominal dtype before performing the
operation.
Parameters
----------
rowwise_data: torch.Tensor
Raw FP4 data in a uint8 tensor (rowwise layout).
rowwise_scale_inv: torch.Tensor
Reciprocal of the scaling factor applied when
casting to FP4, i.e. the scaling factor that must
be applied when casting from FP4 to higher
precision (rowwise).
columnwise_data: torch.Tensor, optional
Raw FP4 data in a uint8 tensor (columnwise layout).
columnwise_scale_inv: torch.Tensor, optional
Reciprocal of the scaling factor for columnwise FP4 data.
amax_rowwise: torch.Tensor, optional
Rowwise amax tracking tensor.
amax_columnwise: torch.Tensor, optional
Columnwise amax tracking tensor.
fp4_dtype: TE_DType
The FP4 data type used for quantization.
quantizer: Quantizer
The quantizer instance used for this tensor.
dtype: torch.dtype, default = torch.float32
Nominal tensor datatype, used in dequantize.
"""
# NOTE: We reorder the *args so that we can instantiate an NVFP4TensorBase with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def __new__(
cls,
*args,
rowwise_data: Optional[torch.Tensor],
rowwise_scale_inv: Optional[torch.Tensor],
columnwise_data: Optional[torch.Tensor],
columnwise_scale_inv: Optional[torch.Tensor],
amax_rowwise: Optional[torch.Tensor],
amax_columnwise: Optional[torch.Tensor],
fp4_dtype: TE_DType,
quantizer: Quantizer,
**kwargs,
):
instance = super().__new__(
cls,
rowwise_data,
rowwise_scale_inv,
columnwise_data,
columnwise_scale_inv,
amax_rowwise,
amax_columnwise,
fp4_dtype,
quantizer,
*args,
**kwargs,
)
return instance
def __repr__(self, *, tensor_contents=None):
return f"NVFP4Tensor, data={self.dequantize(dtype=self.dtype)})"
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Construct plain PyTorch tensor from NVFP4Tensor
By default the resulting tensor's dtype is the
NVFP4Tensor's nominal dtype.
"""
# Convert PyTorch dtype to TE dtype
if dtype is None:
dtype = self.dtype
if torch.is_grad_enabled():
return _FromNVFP4Func.apply(self, dtype)
return _FromNVFP4Func.forward(None, self, dtype)
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
if self._quantizer is not None:
return self._quantizer
return NVFP4Quantizer()
def quantize_(
self,
tensor: torch.Tensor,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> NVFP4Tensor:
"""Update FP8 data
Parameters
----------
tensor: torch.Tensor
Tensor to copy from
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
if isinstance(tensor, QuantizedTensor):
return self.quantize_(tensor.dequantize())
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
def detach(self) -> NVFP4Tensor:
# pylint: disable=missing-function-docstring
# TODO(ksivamani): Fix the detach bug
return NVFP4Tensor.make_like(self)
def clone(self) -> NVFP4Tensor:
# pylint: disable=missing-function-docstring
assert self._rowwise_data is not None
rowwise_data = self._rowwise_data.detach().clone()
columnwise_data = None
if self._columnwise_data is not None:
columnwise_data = self._columnwise_data.detach().clone()
return _IdentityFunc.apply(
self,
{
"rowwise_data": rowwise_data,
"columnwise_data": columnwise_data,
},
)
def view(self, *shape: Tuple[int]) -> NVFP4Tensor:
# pylint: disable=missing-function-docstring
return _ViewFunc.apply(self, shape)
def reshape(self, *shape: Tuple[int]) -> NVFP4Tensor:
# pylint: disable=missing-function-docstring
return _ReshapeFunc.apply(self, shape)
def contiguous(
self,
memory_format: torch.memory_format = torch.contiguous_format,
) -> NVFP4Tensor:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
"""
if self._rowwise_data is not None and self._rowwise_data.is_contiguous(
memory_format=memory_format
):
return self
if self._columnwise_data is not None and self._columnwise_data.is_contiguous(
memory_format=memory_format
):
return self
raise ValueError("NVFP4Tensor does not support different memory formats!")
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# View op
if func == aten.view.default:
if len(args) != 2:
raise RuntimeError("Unexpected args for view op (expected 2 args, got {len(args)})")
tensor = args[0]
shape = args[1]
if shape == list(tensor.size()):
return tensor.detach()
return tensor.view(shape)
# NVFP4 dequantize not supported. Add manual support for needed funcs.
if func in (aten.empty_like.default, aten.zero_.default):
tensor = args[0]
data_init_func = torch.zeros_like if func == aten.zero_.default else torch.empty_like
scale_inv_init_func = (
torch.ones_like if func == aten.zero_.default else torch.empty_like
)
if tensor._rowwise_data is not None:
rowwise_data = data_init_func(tensor._rowwise_data)
rowwise_scale_inv = scale_inv_init_func(tensor._rowwise_scale_inv)
amax_rowwise = torch.zeros_like(tensor._amax_rowwise)
else:
rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None
if tensor._columnwise_data is not None:
columnwise_data = data_init_func(tensor._columnwise_data)
columnwise_scale_inv = scale_inv_init_func(tensor._columnwise_scale_inv)
amax_columnwise = torch.zeros_like(tensor._amax_columnwise)
else:
columnwise_data, columnwise_scale_inv, amax_columnwise = (
None,
None,
None,
)
return NVFP4Tensor(
shape=tensor.shape,
dtype=tensor.dtype,
fp4_dtype=tensor._fp4_dtype,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
amax_rowwise=amax_rowwise,
amax_columnwise=amax_columnwise,
quantizer=tensor._quantizer,
requires_grad=tensor.requires_grad,
)
# Default case
return super().__torch_dispatch__(func, types, args, kwargs)
@classmethod
def _make_in_reduce_ex(
cls,
shape: torch.Size,
rowwise_data: torch.Tensor,
rowwise_scale_inv: torch.Tensor,
columnwise_data: torch.Tensor,
columnwise_scale_inv: torch.Tensor,
amax_rowwise: torch.Tensor,
amax_columnwise: torch.Tensor,
fp4_dtype: TE_DType,
dtype: torch.dtype,
quantizer: Quantizer,
) -> NVFP4Tensor:
"""Build NVFP4Tensor, for use in __reduce__
__reduce_ex__ assumes object constructor has positional
arguments.
"""
return NVFP4Tensor(
shape=shape,
dtype=dtype,
fp4_dtype=fp4_dtype,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
amax_rowwise=amax_rowwise,
amax_columnwise=amax_columnwise,
quantizer=quantizer,
requires_grad=False,
)
def __reduce_ex__(self, protocol: int) -> tuple:
"""Custom pickling"""
return (
NVFP4Tensor._make_in_reduce_ex,
(
self.shape,
self._rowwise_data,
self._rowwise_scale_inv,
self._columnwise_data,
self._columnwise_scale_inv,
self._amax_rowwise,
self._amax_columnwise,
self._fp4_dtype,
self.dtype,
self._quantizer,
),
)
def _get_data(self) -> NVFP4Tensor:
"""Get tensor data property"""
return super().data
@torch.no_grad()
def _set_data(self, tensor: torch.Tensor) -> None:
"""Set tensor data property
Just takes FP4 data if setting from an NVFP4Tensor. Otherwise
casts to NVFP4.
"""
# Tensor device
new_device = tensor.device if tensor.is_cuda else self.device
if not devices_match(new_device, tensor.device):
tensor = tensor.to(device=new_device)
# Just copy FP4 data if other tensor is NVFP4Tensor
if isinstance(tensor, NVFP4Tensor):
if ( # pylint: disable=too-many-boolean-expressions
self.size() != tensor.size()
or self.stride() != tensor.stride()
or self.storage_offset() != tensor.storage_offset()
or self.dtype != tensor.dtype
or self.layout != tensor.layout
or not devices_match(self.device, new_device)
):
dummy_tensor = torch.Tensor._make_wrapper_subclass(
NVFP4Tensor,
tensor.size(),
strides=tensor.stride(),
storage_offset=tensor.storage_offset(),
dtype=tensor.dtype,
layout=tensor.layout,
requires_grad=tensor.requires_grad,
device=new_device,
)
# pylint: disable=unnecessary-dunder-call
super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor)
self._rowwise_data = tensor._rowwise_data
self._columnwise_data = tensor._columnwise_data
self._quantizer = tensor._quantizer
self._rowwise_scale_inv = tensor._rowwise_scale_inv
self._columnwise_scale_inv = tensor._columnwise_scale_inv
self._amax_rowwise = tensor._amax_rowwise
self._amax_columnwise = tensor._amax_columnwise
return
# Quantize to NVFP4
assert self._quantizer is not None, "Can't quantize without a quantizer"
self._quantizer.update_quantized(tensor, self)
if self.requires_grad != tensor.requires_grad:
self.requires_grad_(requires_grad=tensor.requires_grad)
# Cast to NVFP4 when setting NVFP4Tensor.data
data = property(_get_data, _set_data)
class _ViewFunc(torch.autograd.Function):
"""View function
View the NVFP4Tensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: NVFP4Tensor,
shape: Optional[list[int]] = None,
) -> NVFP4Tensor:
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
cur_shape = tensor.shape
if ctx is not None:
ctx.shape = cur_shape
if shape is None:
return tensor
# Canonicalize shape
if not isinstance(shape, Iterable):
shape = [shape]
elif len(shape) == 1 and isinstance(shape[0], Iterable):
shape = shape[0]
if -1 in shape:
shape = list(shape)
d_inferred = -math.prod(cur_shape) // math.prod(shape)
for i, d in enumerate(shape):
if d == -1:
shape[i] = d_inferred
break
if shape[-1] != cur_shape[-1]:
raise RuntimeError(
"NVFP4Tensor does not support reshaping inner dimension "
f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
)
# Reshape data
new_rowwise_data = None
new_columnwise_data = None
if tensor._rowwise_data is not None:
if shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent row-wise data for NVFP4 tensor "
f"with shape={shape} as byte array."
)
byte_shape = list(shape[:-1]) + [shape[-1] // 2]
new_rowwise_data = tensor._rowwise_data.view(byte_shape)
if tensor._columnwise_data is not None:
columnwise_shape = (shape[-1], math.prod(shape[:-1]))
if columnwise_shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent column-wise data for NVFP4 tensor "
f"with shape={shape} as byte array."
)
byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
new_columnwise_data = tensor._columnwise_data.view(byte_shape)
# Construct tensor
return NVFP4Tensor(
shape,
tensor.dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=tensor._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=tensor._columnwise_scale_inv,
amax_rowwise=tensor._amax_rowwise,
amax_columnwise=tensor._amax_columnwise,
quantizer=tensor._quantizer,
fp4_dtype=tensor._fp4_dtype,
requires_grad=tensor.requires_grad,
)
@staticmethod
def backward(
ctx,
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad, NVFP4Tensor):
new_rowwise_data = None
new_columnwise_data = None
if grad._rowwise_data is not None:
if ctx.shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent row-wise data for NVFP4 tensor "
f"with shape={ctx.shape} as byte array."
)
byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2]
new_rowwise_data = grad._rowwise_data.view(byte_shape)
if grad._columnwise_data is not None:
columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1]))
if columnwise_shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent column-wise data for NVFP4 tensor "
f"with shape={ctx.shape} as byte array."
)
byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
new_columnwise_data = grad._columnwise_data.view(byte_shape)
dgrad = NVFP4Tensor(
ctx.shape,
grad.dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=grad._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=grad._columnwise_scale_inv,
amax_rowwise=grad._amax_rowwise,
amax_columnwise=grad._amax_columnwise,
quantizer=grad._quantizer,
fp4_dtype=grad._fp4_dtype,
requires_grad=grad.requires_grad,
)
return dgrad, None
return grad.view(ctx.shape), None
class _ReshapeFunc(torch.autograd.Function):
"""Reshape function
Reshape the NVFP4Tensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: NVFP4Tensor,
shape: Optional[list[int]] = None,
) -> NVFP4Tensor:
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
cur_shape = tensor.shape
if ctx is not None:
ctx.shape = cur_shape
if shape is None:
return tensor
# Canonicalize shape
if not isinstance(shape, Iterable):
shape = [shape]
elif len(shape) == 1 and isinstance(shape[0], Iterable):
shape = shape[0]
if -1 in shape:
shape = list(shape)
d_inferred = -math.prod(cur_shape) // math.prod(shape)
for i, d in enumerate(shape):
if d == -1:
shape[i] = d_inferred
break
if shape[-1] != cur_shape[-1]:
raise RuntimeError(
"NVFP4Tensor does not support reshaping inner dimension "
f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
)
# Reshape data
new_rowwise_data = None
new_columnwise_data = None
if tensor._rowwise_data is not None:
if shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent row-wise data for NVFP4 tensor "
f"with shape={shape} as byte array."
)
byte_shape = list(shape[:-1]) + [shape[-1] // 2]
new_rowwise_data = tensor._rowwise_data.reshape(byte_shape)
if tensor._columnwise_data is not None:
columnwise_shape = (shape[-1], math.prod(shape[:-1]))
if columnwise_shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent column-wise data for NVFP4 tensor "
f"with shape={shape} as byte array."
)
byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
new_columnwise_data = tensor._columnwise_data.reshape(byte_shape)
# Construct tensor
return NVFP4Tensor(
shape,
tensor.dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=tensor._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=tensor._columnwise_scale_inv,
amax_rowwise=tensor._amax_rowwise,
amax_columnwise=tensor._amax_columnwise,
quantizer=tensor._quantizer,
fp4_dtype=tensor._fp4_dtype,
requires_grad=tensor.requires_grad,
)
@staticmethod
def backward(
ctx,
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad, NVFP4Tensor):
new_rowwise_data = None
new_columnwise_data = None
if grad._rowwise_data is not None:
if ctx.shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent row-wise data for NVFP4 tensor "
f"with shape={ctx.shape} as byte array."
)
byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2]
new_rowwise_data = grad._rowwise_data.reshape(byte_shape)
if grad._columnwise_data is not None:
columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1]))
if columnwise_shape[-1] % 2 != 0:
raise ValueError(
"Cannot represent column-wise data for NVFP4 tensor "
f"with shape={ctx.shape} as byte array."
)
byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
new_columnwise_data = grad._columnwise_data.reshape(byte_shape)
dgrad = NVFP4Tensor(
ctx.shape,
grad.dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=grad._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=grad._columnwise_scale_inv,
amax_rowwise=grad._amax_rowwise,
amax_columnwise=grad._amax_columnwise,
quantizer=grad._quantizer,
fp4_dtype=grad._fp4_dtype,
requires_grad=grad.requires_grad,
)
return dgrad, None
return grad.view(ctx.shape), None
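A quick worked example of the -1 inference used in both forward passes above (illustrative only, not part of the diff): with exactly one -1 entry, math.prod(shape) is negative, so negating the numerator and floor-dividing recovers the missing dimension.

import math

cur_shape = (4, 8, 64)   # 2048 elements
shape = [-1, 64]         # infer the leading dimension
d_inferred = -math.prod(cur_shape) // math.prod(shape)   # -2048 // -64 == 32
assert d_inferred == 32
assert [d_inferred if d == -1 else d for d in shape] == [32, 64]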
@@ -264,6 +264,10 @@ class Quantizer(abc.ABC):
"""Returns True if the quantizer supports only rowwise all-gather"""
return False
def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-argument
"""Returns whether or not given tensor can be quantized"""
return True
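Purely as an illustrative sketch (the class and its divisibility constraint below are hypothetical, not from this diff), a subclass can override is_quantizable to reject shapes its packing cannot represent; the all-gather assertion added later in this commit then surfaces the failure early.

import torch

class _ToyBlockQuantizer:
    """Hypothetical stand-in for a Quantizer subclass with a block-size constraint."""

    block_size = 16

    def is_quantizable(self, inp: torch.Tensor) -> bool:
        # Reject tensors whose innermost dimension does not fill whole blocks.
        return inp.dim() >= 1 and inp.shape[-1] % self.block_size == 0

q = _ToyBlockQuantizer()
assert q.is_quantizable(torch.empty(8, 64))
assert not q.is_quantizable(torch.empty(8, 65))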
class _QuantizeFunc(torch.autograd.Function):
"""Cast to FP8 from other dtype"""
@@ -4,11 +4,13 @@
"""Helper functions for using fp8 tensors as weights"""
import os
from typing import Optional, Union
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
from .quantized_tensor import QuantizedTensor
from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorBase
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
@@ -450,3 +452,20 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
tex.fp8_block_scaling_partial_cast(
master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype
)
def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) -> bool:
"""Check if an environment or object is using experimental Kitchen middleware.
Returns False if x is a torch.Tensor.
"""
# Detect if the environment is experimental
if x is None:
return int(os.getenv("QAT_PARAMS", "0")) > 0
# Detect if the object is experimental
if isinstance(x, torch.Tensor):
return False
if not isinstance(x, (Quantizer, QuantizedTensorBase)):
raise AssertionError("Object must be a Quantizer or QuantizedTensorBase instance")
return hasattr(x, "experimental") and x.experimental
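For reference, the environment gate above reduces to whether QAT_PARAMS parses to a positive integer; a tiny self-contained check of just that expression (illustrative only):

import os

os.environ["QAT_PARAMS"] = "1"
assert (int(os.getenv("QAT_PARAMS", "0")) > 0) is True
os.environ.pop("QAT_PARAMS")
assert (int(os.getenv("QAT_PARAMS", "0")) > 0) is False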
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""NVFP4 padding kernels
TODO(ksivamani): Documentation
"""
import torch
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1),
],
key=["out_dim0", "out_dim1"],
)
@triton.jit
def zero_pad_kernel(
inp_ptr,
out_ptr,
in_dim0: tl.constexpr,
in_dim1: tl.constexpr,
out_dim0: tl.constexpr,
out_dim1: tl.constexpr,
in_s0,
in_s1,
out_s0,
out_s1,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""Pads a tensor assuming it's a columnwise scaling inverse."""
# tile over OUTPUT coordinates
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # output rows
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # output cols
om = offs_m[:, None]
on = offs_n[None, :]
# edge masking for output
out_mask = (om < out_dim0) & (on < out_dim1)
# valid input region is simply top-left (no offsets)
in_mask = (om < in_dim0) & (on < in_dim1)
# load valid input, else zero (masked load touches memory only where True)
x = tl.load(inp_ptr + om * in_s0 + on * in_s1, mask=in_mask, other=0)
# store to output (only within bounds of the output tile)
tl.store(out_ptr + om * out_s0 + on * out_s1, x, mask=out_mask)
def pad_columnwise_scale_inv(inp: torch.Tensor) -> torch.Tensor:
"""Pads a tensor assuming it's a columnwise scaling inverse."""
assert inp.ndim == 2
dim0, dim1 = inp.shape
pad_x = (128 - dim0 % 128) % 128
pad_y = (4 - dim1 % 4) % 4
out_x = dim0 + pad_x
out_y = dim1 + pad_y
out = torch.empty((out_x, out_y), device=inp.device, dtype=inp.dtype)
in_s0, in_s1 = inp.stride()
out_s0, out_s1 = out.stride()
BLOCK_M, BLOCK_N = 128, 128
grid = (triton.cdiv(out_x, BLOCK_M), triton.cdiv(out_y, BLOCK_N))
zero_pad_kernel[grid](
inp,
out,
dim0,
dim1,
out_x,
out_y,
in_s0,
in_s1,
out_s0,
out_s1,
)
return out
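As a sanity check, the same padding can be expressed with torch.nn.functional.pad; this reference (a sketch, not part of the diff) should match the Triton path above on a CUDA device: rows are zero-padded to a multiple of 128, columns to a multiple of 4, with the original data kept in the top-left corner.

import torch
import torch.nn.functional as F

def pad_columnwise_scale_inv_reference(inp: torch.Tensor) -> torch.Tensor:
    # Zero-pad rows to a multiple of 128 and columns to a multiple of 4.
    dim0, dim1 = inp.shape
    pad_x = (128 - dim0 % 128) % 128
    pad_y = (4 - dim1 % 4) % 4
    return F.pad(inp, (0, pad_y, 0, pad_x), value=0)

x = torch.arange(6, dtype=torch.uint8).reshape(2, 3)
ref = pad_columnwise_scale_inv_reference(x)
assert ref.shape == (128, 4)
assert torch.equal(ref[:2, :3], x)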
@@ -11,8 +11,8 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import transformer_engine.pytorch.cpp_extensions as ext
from . import torch_version
from .tensor.quantized_tensor import Quantizer
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
@@ -441,6 +441,16 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
)
def assert_dim_for_all_gather(
tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer
) -> None:
"""Assert that tensor dimensions are supported for all-gather"""
if with_all_gather:
assert quantizer.is_quantizable(tensor), (
"All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__
)
def is_bf16_compatible() -> None:
"""Replaces torch.cuda.is_bf16_compatible() with an explicit
check on device compute capability to enforce sm_80 or higher.
@@ -460,6 +470,8 @@ def is_non_tn_fp8_gemm_supported() -> bool:
@functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]:
"""Runtime cuDNN version (major, minor, patch)"""
import transformer_engine.pytorch.cpp_extensions as ext
encoded_version = ext.get_cudnn_version()
major_version_magnitude = 1000 if encoded_version < 90000 else 10000
major, encoded_version = divmod(encoded_version, major_version_magnitude)
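To illustrate the version decoding started above (a sketch based only on the lines shown): cuDNN releases before 9.0 encode the version with a factor of 1000, later releases with 10000, so the major version falls out of a divmod against the appropriate magnitude.

for encoded_version, expected_major in ((8902, 8), (90100, 9)):
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, _rest = divmod(encoded_version, major_version_magnitude)
    assert major == expected_major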