Unverified Commit 1df4a69f authored by Evgeny Tsykunov, committed by GitHub

[PyTorch] Enable reference Current Scaling recipe (#2368)



* Enable reference current scaling recipe
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* minor
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* linter
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Test ref vs native
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
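
For context, the reference recipe added here is wired in through CustomRecipe, as described in the quantizer factory docstring in the new module below. A minimal usage sketch (the model/input names and Linear sizes are placeholders; the autocast(recipe=...) call and the import paths follow the factory docstring and module added in this diff, so treat them as assumptions rather than the canonical API):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe
    from transformer_engine.pytorch.quantization import autocast
    from transformer_engine.pytorch.custom_recipes.quantization_current_scaling import (
        current_scaling_ref_quantizer_factory,
    )

    # Placeholder module and input for illustration only
    model = te.Linear(1024, 1024, bias=False).cuda()
    inp = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)

    # Build a custom recipe that uses the reference current scaling quantizers
    custom_recipe = recipe.CustomRecipe(qfactory=current_scaling_ref_quantizer_factory)
    with autocast(recipe=custom_recipe):
        out = model(inp)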
parent e1edaaec
@@ -8,9 +8,15 @@ import torch
import pytest
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Float8CurrentScaling
from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.custom_recipes.quantization import MMParams
from transformer_engine.pytorch.custom_recipes.quantization_current_scaling import (
    CurrentScalingQuantizerRef,
)
# read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory
@@ -749,6 +755,132 @@ class TestFP8CurrentScalingRecipeLinear(TestFP8RecipeLinearBase):
    )
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8CurrentScalingNativeVsRef:
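    """Compare TE's native FP8 current scaling quantization and GEMM against the
    Python reference implementation (CurrentScalingQuantizerRef) added in this PR."""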
@staticmethod
def _make_quantizers(rowwise=True, columnwise=True):
# TE native FP8 current scaling quantizer
te_quant = te.Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=torch.device("cuda"),
rowwise=rowwise,
columnwise=columnwise,
)
# Reference quantizer
ref_quant = CurrentScalingQuantizerRef(
dtype=torch.float8_e4m3fn,
rowwise=rowwise,
columnwise=columnwise,
pow_2_scales=False,
eps=0.0,
)
return te_quant, ref_quant
@pytest.mark.parametrize(
"M, N, dtype",
[
(128, 256, torch.bfloat16),
],
ids=["rowwise"],
)
def test_current_scaling_quantization_versus_reference(self, M, N, dtype):
device = "cuda"
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn((M, N), dtype=dtype, device=device)
te_quant, ref_quant = self._make_quantizers(rowwise=True, columnwise=False)
# Native TE quantization
x_te = te_quant(x)
assert x_te._data is not None
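        # The native tensor stores raw FP8 bytes in a uint8 buffer; reinterpret them
        # as float8_e4m3fn so they can be compared directly with the reference output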
qx_native = x_te._data.view(dtype=torch.float8_e4m3fn)
sx_native = x_te._scale_inv
# Reference quantization
x_ref = ref_quant.quantize(x)
qx_ref = x_ref.data
sx_ref = x_ref.scale
# Byte-for-byte equality on data and exact scale_inv match
torch.testing.assert_close(qx_native, qx_ref, atol=0.0, rtol=0.0)
torch.testing.assert_close(sx_native, sx_ref, atol=0.0, rtol=0.0)
@pytest.mark.parametrize(
"M, K, N, out_dtype, accumulate",
[
(128, 256, 96, torch.bfloat16, False),
(64, 128, 64, torch.float32, True),
],
ids=["bf16_no_acc", "fp32_acc"],
)
def test_current_scaling_gemm_versus_reference(self, M, K, N, out_dtype, accumulate):
device = "cuda"
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn((M, K), dtype=torch.bfloat16, device=device)
w = torch.randn((N, K), dtype=torch.bfloat16, device=device)
out = torch.randn((M, N), dtype=out_dtype, device=device) if accumulate else None
te_quant_x, ref_quant = self._make_quantizers(rowwise=True, columnwise=True)
te_quant_w, _ = self._make_quantizers(rowwise=True, columnwise=True)
# Native TE quantization (direct)
qx_native = te_quant_x(x)
qw_native = te_quant_w(w)
# Prepare inputs for reference qgemm
assert qx_native._data is not None and qw_native._data is not None
qx_data = qx_native._data.view(dtype=torch.float8_e4m3fn)
qw_data = qw_native._data.view(dtype=torch.float8_e4m3fn)
sx = qx_native._scale_inv
sw = qw_native._scale_inv
# Reference GEMM
m_params = MMParams(out_dtype=out_dtype, use_split_accumulator=False)
y_ref = ref_quant.qgemm(
qx=qx_data,
qw=qw_data,
m_params=m_params,
out_dtype=out_dtype,
sx=sx,
sw=sw,
bias=None,
out=out.clone() if accumulate else None,
accumulate=accumulate,
gemm_type=None,
qresult_x=None,
qresult_w=None,
)
# Native TE GEMM
# return type is out, bias_grad, gelu_input, extra_output
y_native = tex.generic_gemm(
qw_native, # A
True, # transa (treat (N,K) as (K,N))
qx_native, # B
False, # transb
out.clone() if accumulate else None,
None, # out quantizer
TE_DType[out_dtype],
None, # bias
TE_DType[torch.bfloat16],
False, # use_gelu
None, # gelu_input
False, # use_grad
torch.empty(0, dtype=torch.uint8, device=device),
0,
accumulate,
False, # use_split_accumulator
)[0]
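        # Both paths run an FP8 GEMM on identical quantized inputs and scales,
        # so the results are expected to match exactly (atol=0, rtol=0)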
torch.testing.assert_close(y_native, y_ref, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8CurrentScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase):
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Current scaling recipe reference implementation."""
import dataclasses
import math
from typing import Optional, Tuple, Iterable
import torch
from transformer_engine.pytorch.custom_recipes import quantization
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer
def current_scaling_ref_quantizer_factory(role):
"""Factory function for current scaling reference quantizer.
Usage with CustomRecipe and autocast:
custom_recipe = recipe.CustomRecipe(qfactory=current_scaling_ref_quantizer_factory)
with autocast(recipe=custom_recipe):
output = model(input)
"""
if role in ("linear_input", "linear_weight"):
dtype = torch.float8_e4m3fn
elif role in ("linear_output", "linear_grad_output"):
dtype = torch.float8_e5m2
else:
return None
return CurrentScalingQuantizerRef(
dtype=dtype,
rowwise=True,
columnwise=True,
pow_2_scales=False,
eps=0.0,
)
@dataclasses.dataclass
class CurrentScalingTensorRef(QuantizedTensorStorage):
"""Reference implementation of current scaling quantized tensor"""
data: Optional[torch.Tensor] = None
scale: Optional[torch.Tensor] = None
data_t: Optional[torch.Tensor] = None
scale_t: Optional[torch.Tensor] = None
dtype: Optional[torch.dtype] = None
device: Optional[torch.device] = None
quant_dtype: Optional[torch.dtype] = None
original_shape: Optional[Tuple[int, ...]] = None
_quantizer: Optional[Quantizer] = None
@property
def custom(self) -> bool:
"""Flag to indicate this quantized tensor is custom."""
return True
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the quantization result for saving for backward"""
tensors = [self.data, self.data_t, self.scale, self.scale_t]
self.data = None
self.data_t = None
self.scale = None
self.scale_t = None
return tensors, self
def restore_from_saved(
self, tensors: list[Optional[torch.Tensor]]
) -> list[Optional[torch.Tensor]]:
"""Restore the quantization result from the saved tensors"""
self.data = tensors[0]
self.data_t = tensors[1]
self.scale = tensors[2]
self.scale_t = tensors[3]
return tensors[4:]
# Compatibility
@property
def _data(self):
return self.data
@_data.setter
def _data(self, value):
self.data = value
@property
def _scale_inv(self):
return self.scale
@_scale_inv.setter
def _scale_inv(self, value):
self.scale = value
def __repr__(self):
return (
f"{self.__class__.__name__}("
f"dtype={self.dtype}, "
f"device={self.device}, "
f"quant_dtype={self.quant_dtype}, "
f"original_shape={self.original_shape}"
")"
)
def update_usage(
self,
rowwise_usage: Optional[bool] = None,
columnwise_usage: Optional[bool] = None,
):
"""Generate or remove quantized data based on provided usage."""
has_data = self.data is not None
has_data_transpose = self.data_t is not None
needs_data = has_data
needs_data_transpose = has_data_transpose
if rowwise_usage is not None:
needs_data = rowwise_usage
if columnwise_usage is not None:
needs_data_transpose = columnwise_usage
# Generate data that is required
if needs_data and not has_data:
raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
if needs_data_transpose and not has_data_transpose:
if not has_data:
raise RuntimeError("FP8 data is required to generate FP8 data transpose")
self._create_transpose()
# Delete data that is not required
if not needs_data:
self.data = None
if not needs_data_transpose:
self.data_t = None
def _create_transpose(self):
"""Create transposed quantized tensor"""
if not self.data.is_contiguous():
self.data = self.data.contiguous()
self.data_t = self.data.t().contiguous()
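        # Per-tensor current scaling: the transpose reuses the same scale_inv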
self.scale_t = self.scale
def size(self, *args, **kwargs):
"""Get the size of the quantized tensor"""
if self.data is not None:
return self.data.size(*args, **kwargs)
size = self.data_t.size(*args, **kwargs)
return torch.Size([size[-1], math.prod(size[:-1])])
def _scale_from_amax_tensor(
x_dtype: torch.dtype,
amax: torch.Tensor,
quant_dtype: torch.dtype,
*,
eps: float,
pow_2_scales: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Derives quantization and dequantization from amax and options.
Reference implementation for scale calculation.
Returns:
- scale: quantization scales
- scale_inv: dequantization scales
- amax: Amax tensor with updates made for extrema values.
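    Example: for amax = 2.0 and quant_dtype = torch.float8_e4m3fn (max representable
    value 448.0), scale = 448.0 / 2.0 = 224.0 and scale_inv = 1 / 224.0.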
"""
assert amax.dtype == torch.float, "amax must be a float tensor."
fp8_max = torch.finfo(quant_dtype).max
# Clamping amax to avoid division by small numbers
amax = torch.max(amax, torch.tensor(eps))
# Compute scale factor
scale = torch.div(fp8_max, amax)
# Take care of inf before pow_2_scales
scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale)
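    # Optionally round the scale down to a power of two via frexp/ldexp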
if pow_2_scales:
_, exp = torch.frexp(scale)
exp = exp - 1
assert (exp > -127).all()
unity = torch.tensor([1.0], device=exp.device)
torch.ldexp(unity, exp, out=scale)
scale = torch.where(amax == float("inf"), 0.0, scale)
# Handle overflow cases for amax zero causing NaN
scale = torch.where(amax == 0, 1.0, scale)
# Compute scale_inv
scale_inv = torch.reciprocal(scale)
return scale, scale_inv, amax
class CurrentScalingQuantizerRef(Quantizer):
"""Reference implementation of current scaling quantizer"""
def __init__(
self,
dtype: torch.dtype,
rowwise: bool = True,
columnwise: bool = True,
pow_2_scales: bool = False,
eps: float = 0.0,
):
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.internal = True
self.dtype = dtype
self.pow_2_scales = pow_2_scales
self.eps = eps
self.with_amax_reduction = False
self.amax_reduction_group = None
@property
def custom(self) -> bool:
"""Flag to indicate this quantizer is custom."""
return True
@property
def supports_allgather_fp8(self) -> bool:
"""Flag to indicate this quantizer supports allgather fp8"""
return True
@classmethod
def compute_scale(
cls,
x: torch.Tensor,
quant_dtype: torch.dtype,
eps=0.0,
pow_2_scales: bool = False,
):
"""Compute the scale from the amax tensor"""
# Use float32 for computation
x_fp32 = x.to(torch.float32)
if x_fp32.numel() == 0:
amax = torch.empty(1, dtype=torch.float32, device=x.device)
else:
amax = torch.amax(torch.abs(x_fp32)).view(1)
return _scale_from_amax_tensor(
x.dtype,
amax=amax,
quant_dtype=quant_dtype,
eps=eps,
pow_2_scales=pow_2_scales,
)
def _quantize(self, tensor: torch.Tensor) -> Tuple[
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
"""
        Python implementation of quantization (a C++ kernel can be used instead).
Parameters
----------
tensor : torch.Tensor
Input tensor to quantize (should be 2D)
Returns
-------
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]
(qx, sx, qx_t, sx_t) where:
- qx: quantized data in row-major order (if rowwise_usage), None otherwise
        - sx: dequantization scale (scale_inv) for qx (if rowwise_usage), None otherwise
        - qx_t: quantized data in column-major order (if columnwise_usage), None otherwise
        - sx_t: dequantization scale (scale_inv) for qx_t (if columnwise_usage), None otherwise
"""
# Handle amax reduction if enabled
if self.with_amax_reduction:
assert (
self.amax_reduction_group is not None
), "amax_reduction_group must be set when with_amax_reduction is True"
# Compute local amax
if tensor.numel() == 0:
amax = torch.empty(1, dtype=torch.float32, device=tensor.device)
else:
amax = torch.amax(torch.abs(tensor)).view(1).to(torch.float32)
# Reduce amax across all ranks
torch.distributed.all_reduce(
amax, group=self.amax_reduction_group, op=torch.distributed.ReduceOp.MAX
)
# Compute scale using the global amax
scale, scale_inv, _ = _scale_from_amax_tensor(
tensor.dtype,
amax=amax,
quant_dtype=self.dtype,
eps=self.eps,
pow_2_scales=self.pow_2_scales,
)
else:
# compute scale factor using local amax
scale, scale_inv, _ = self.compute_scale(
tensor,
self.dtype,
eps=self.eps,
pow_2_scales=self.pow_2_scales,
)
qx: Optional[torch.Tensor] = (tensor.float() * scale).to(self.dtype)
sx: Optional[torch.Tensor] = scale_inv
# transpose if needed
if self.columnwise_usage:
assert qx is not None
qx_t = qx.t().contiguous()
sx_t = sx
else:
qx_t, sx_t = None, None
if not self.rowwise_usage:
qx = None
sx = None
return qx, sx, qx_t, sx_t
def quantize(
self,
tensor: torch.Tensor,
**kwargs, # pylint: disable=unused-argument
) -> CurrentScalingTensorRef:
# sanity checks
assert tensor.dtype in utils.HIGH_PRECISION_FLOAT_DTYPES, "Unsupported input dtype."
# Make it work with 3D tensors
original_shape = tensor.shape
if tensor.ndim > 2:
tensor = tensor.view(-1, tensor.shape[-1])
qx, sx, qx_t, sx_t = self._quantize(tensor)
return CurrentScalingTensorRef(
data=qx,
scale=sx,
data_t=qx_t,
scale_t=sx_t,
dtype=tensor.dtype,
device=tensor.device,
quant_dtype=self.dtype,
_quantizer=self,
original_shape=original_shape,
)
def dequantize(
self, tensor: torch.Tensor, scale: torch.Tensor, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
"""Dequantize the quantized tensor"""
tensor = tensor.to(torch.float32) * scale
if dtype is None:
return tensor
return tensor.to(dtype)
def qgemm(
self,
qx: torch.Tensor,
qw: torch.Tensor,
m_params: quantization.MMParams,
out_dtype: torch.dtype,
sx: torch.Tensor,
sw: torch.Tensor,
bias: torch.Tensor | None = None,
out: torch.Tensor | None = None,
accumulate: bool = False,
gemm_type: quantization.GEMMType = quantization.GEMMType.FPROP, # pylint: disable=unused-argument
qresult_x: QuantizedTensorStorage | None = None, # pylint: disable=unused-argument
qresult_w: QuantizedTensorStorage | None = None, # pylint: disable=unused-argument
) -> torch.Tensor:
"""Python implementation of quantized gemm."""
M, K = qx.shape
N, _ = qw.shape
if M == 0 or K == 0 or N == 0:
if accumulate:
assert out is not None
y = out
else:
y = torch.zeros((M, N), dtype=out_dtype, device=qx.device)
if bias is not None:
y += bias
return y
# cublas fp8 gemm does not support fp32 bias
use_bias_in_gemm = (
bias is not None and out_dtype != torch.float32 and bias.dtype != torch.float32
)
# Run quantized gemm: y = qw * qx
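        # qw is stored row-major as (N, K); the transpose yields a column-major (K, N)
        # view, which torch._scaled_mm expects for its second operand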
scaled_mm_res = torch._scaled_mm(
qx,
qw.transpose(-1, -2),
scale_a=sx,
scale_b=sw,
out_dtype=out_dtype,
use_fast_accum=not m_params.use_split_accumulator,
bias=bias if use_bias_in_gemm else None,
)
y = scaled_mm_res[0] if isinstance(scaled_mm_res, tuple) else scaled_mm_res
if bias is not None and not use_bias_in_gemm:
# Check number of elements in bias tensor because it can be an empty tensor
if bias.numel():
y += bias
if accumulate:
assert out is not None, "Output tensor must be provided for accumulation."
out.add_(y)
y = out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y
def transpose_qresult(self, qresult: CurrentScalingTensorRef) -> CurrentScalingTensorRef:
"""Python implementation of transpose qresult."""
qx = qresult.data
scale = qresult.scale
assert qresult.data_t is None
assert qresult.scale_t is None
assert qx is not None
qx_t = qx.transpose(-2, -1).contiguous()
scale_t = scale
qresult.data_t = qx_t
qresult.scale_t = scale_t
return qresult
def update_quantized(
self,
src: torch.Tensor,
dst: QuantizedTensorStorage,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensorStorage:
"""Update the quantized tensor with the given tensor in-place
Parameters
----------
src: torch.Tensor
Source tensor to copy from
        dst: QuantizedTensorStorage
            Destination quantized tensor (CurrentScalingTensorRef) to update
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
# Handle noop flag
if noop_flag is not None and noop_flag.item() != 0:
return dst
# Make sure input is in expected format
if not src.is_contiguous():
src = src.contiguous()
# Store the original shape and reshape for processing
original_shape = src.shape
if src.ndim > 2:
src = src.view(-1, src.shape[-1])
qx, sx, qx_t, sx_t = self._quantize(src)
# Update the destination with new data
dst.data = qx
dst.scale = sx
dst.data_t = qx_t
dst.scale_t = sx_t
dst.dtype = src.dtype
dst.quant_dtype = self.dtype
dst.original_shape = original_shape
return dst
def make_empty(
self,
shape: Iterable[int],
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False, # pylint: disable=unused-argument
) -> CurrentScalingTensorRef:
assert len(shape) == 2, "shape is not 2d"
# Canonicalize tensor attributes
if device is None:
device = torch.device("cuda")
# Allocate quantized data
qx = torch.empty(shape, dtype=self.dtype, device=device)
sx = torch.empty(1, dtype=torch.float32, device=device)
# Allocate quantized data transpose if needed
qx_t = None
sx_t = None
if self.columnwise_usage:
inner_dim = qx.size(-1)
qx_t = torch.empty(
inner_dim,
qx.numel() // inner_dim,
dtype=self.dtype,
device=device,
)
sx_t = torch.empty(1, dtype=torch.float32, device=device)
# Construct quantized tensor
return CurrentScalingTensorRef(
data=qx,
scale=sx,
data_t=qx_t,
scale_t=sx_t,
dtype=dtype,
device=device,
quant_dtype=self.dtype,
_quantizer=self,
original_shape=shape,
)
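
A quick round-trip sketch of the reference quantizer above (assumes a CUDA device with FP8 support; tensor sizes are arbitrary):

    quantizer = CurrentScalingQuantizerRef(dtype=torch.float8_e4m3fn, rowwise=True, columnwise=False)
    x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
    qx = quantizer.quantize(x)  # CurrentScalingTensorRef holding FP8 data and scale_inv
    x_hat = quantizer.dequantize(qx.data, qx.scale, dtype=torch.bfloat16)  # approximate reconstruction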
@@ -18,9 +18,9 @@ def nvfp4_ref_rht_2d_quantizer_factory(role):
    """
    Quantizer factory for NVFP4 recipe reference implementation (RHT and 2D quantization for weights).
    Usage with CustomRecipe and autocast:
        custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory)
        with autocast(fp8_recipe=custom_recipe):
            output = model(input)
    """
    if role == "linear_input":
@@ -338,7 +338,7 @@ def get_wgrad_sign_vector() -> torch.Tensor:
class NVFP4QuantizerRef(Quantizer):
    """Reference implementation of NVFP4 quantizer"""
    def __init__(
        self,
...