Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -7,7 +7,16 @@ import sys
import pytest
import torch
import transformer_engine
from transformer_engine.pytorch import (
DotProductAttention,
TransformerLayer,
Linear,
GroupedLinear,
NVFP4Quantizer,
autocast,
is_nvfp4_available,
)
from transformer_engine.common import recipe
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
......@@ -17,9 +26,13 @@ model_configs = {
"small": ModelConfig(2, 10, 2, 16),
}
nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"])
@pytest.mark.parametrize(
"module", ["TransformerLayer", "DotProductAttention", "Linear", "GroupedLinear"]
)
def test_current_device(model, module):
"""Test cases where current device is different from tensor device"""
......@@ -42,7 +55,29 @@ def test_current_device(model, module):
self_attn_mask_type="padding",
device=f"cuda:{tensor_device}",
)
seqlens_q = torch.randint(
1,
config.max_seqlen_q,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_q = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
seqlens_kv = torch.randint(
1,
config.max_seqlen_kv,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_kv = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
num_tokens = cu_seqlens_q[-1]
args = [
torch.randn(
(num_tokens, config.hidden_size),
......@@ -51,37 +86,55 @@ def test_current_device(model, module):
requires_grad=True,
)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
if module == "DotProductAttention":
elif module == "DotProductAttention":
model = DotProductAttention(
config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
)
seqlens_q = torch.randint(
1,
config.max_seqlen_q,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_q = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
seqlens_kv = torch.randint(
1,
config.max_seqlen_kv,
[config.batch_size],
dtype=torch.int32,
device=f"cuda:{tensor_device}",
)
cu_seqlens_kv = torch.zeros(
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
num_tokens = cu_seqlens_q[-1]
args = [
torch.randn(
num_tokens,
config.num_heads,
config.head_dim_qk,
dtype=dtype,
device=f"cuda:{tensor_device}",
requires_grad=True,
)
for _ in range(3)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
bwd_args = [
torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=f"cuda:{tensor_device}")
]
elif module == "Linear":
model = Linear(
config.hidden_size,
......@@ -97,6 +150,24 @@ def test_current_device(model, module):
requires_grad=True,
)
]
elif module == "GroupedLinear":
num_gemms = 4
model = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
params_dtype=dtype,
device=f"cuda:{tensor_device}",
)
args = [
torch.randn(
(config.max_seqlen_q * config.batch_size * (num_gemms - 1), config.hidden_size),
dtype=dtype,
device=f"cuda:{tensor_device}",
requires_grad=True,
),
[0] + [config.max_seqlen_q * config.batch_size] * (num_gemms - 1), # Empty first split.
]
current_device_before = torch.cuda.current_device()
out = model(*args, **kwargs)
......@@ -118,3 +189,24 @@ def test_current_device(model, module):
assert (
tensor_device_grad == tensor_device
), "The gradient tensor should be the same as the input tensors!"
@pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4)
def test_nvfp4_rht_cache():
"""Ensure correct RHT cache for NVFP4."""
num_devices = torch.cuda.device_count()
assert num_devices > 1, "This test requires more than one GPU!"
# Populate cache on last device.
with torch.cuda.device(num_devices - 1):
_ = NVFP4Quantizer()
hidden_size = 128
dtype = torch.bfloat16
model = Linear(hidden_size, hidden_size, params_dtype=dtype)
inp = torch.randn(hidden_size, hidden_size, device=torch.cuda.current_device(), dtype=dtype)
fp4_recipe = recipe.NVFP4BlockScaling()
with autocast(recipe=fp4_recipe):
_ = model(inp)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
from transformer_engine.pytorch import LayerNormMLP
import pytest
torch.manual_seed(1234)
device = torch.device("cuda")
class _Sequential(torch.nn.Sequential):
"""Sequential model that forwards keyword arguments to modules"""
def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
x = input_
for module in self:
x = module(x, **kwargs)
return x
class ModelConfig:
def __init__(
self,
hidden_size: int = 128,
ffn_hidden_size: int = 512,
layers: int = 1,
):
self._hidden_size = hidden_size
self._ffn_hidden_size = ffn_hidden_size
self._layers = layers
def build(self):
ln_list, sln_list = [], []
for _ in range(self._layers):
ln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=False).to(device)
sln = LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=True).to(device)
with torch.no_grad():
sln.layer_norm_weight = torch.nn.Parameter(ln.layer_norm_weight.clone())
sln.layer_norm_bias = torch.nn.Parameter(ln.layer_norm_bias.clone())
sln.fc1_weight = torch.nn.Parameter(ln.fc1_weight.clone())
sln.fc2_weight = torch.nn.Parameter(ln.fc2_weight.clone())
sln.fc1_bias = torch.nn.Parameter(ln.fc1_bias.clone())
sln.fc2_bias = torch.nn.Parameter(ln.fc2_bias.clone())
ln_list.append(ln)
sln_list.append(sln)
ln_model = _Sequential(*ln_list)
sln_model = _Sequential(*sln_list)
return ln_model, sln_model
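# Illustrative usage (a sketch; assumes a visible CUDA device): build() returns
# a native stack and a checkpointed stack with identical initial weights, so
# the test below can compare their outputs and gradients exactly.
#   ln_model, sln_model = ModelConfig(128, 512, layers=2).build()
#   x = torch.randn(16, 128, device=device)
#   assert torch.equal(ln_model(x), sln_model(x))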
config = {
"small": ModelConfig(128, 512, 12),
"medium": ModelConfig(512, 2048, 12),
"large": ModelConfig(1024, 4096, 12),
"huge": ModelConfig(2048, 8192, 12),
}
seq_sizes = [2**7, 2**10, 2**14, 2**16]
def _warmup(model, tensor):
for _ in range(3):
model(tensor).sum().backward()
def _run_fwd(model, tensor):
torch.cuda.reset_peak_memory_stats(device)
start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
enable_timing=True
)
torch.cuda.synchronize()
start_mem = torch.cuda.memory_allocated(device)
start_time.record()
out = model(tensor)
end_time.record()
end_time.synchronize()
elapsed = start_time.elapsed_time(end_time)
peak_mem = torch.cuda.max_memory_allocated(device)
mem = float(peak_mem - start_mem)
return out, elapsed, mem
def _run_bwd(model, out):
model.zero_grad(set_to_none=False)
loss = out.sum()
torch.cuda.reset_peak_memory_stats(device)
start_time, end_time = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
enable_timing=True
)
torch.cuda.synchronize()
start_mem = torch.cuda.memory_allocated(device)
start_time.record()
loss.backward()
end_time.record()
end_time.synchronize()
elapsed = start_time.elapsed_time(end_time)
peak_mem = torch.cuda.max_memory_allocated(device)
mem = float(peak_mem - start_mem)
param_grads = _collect_param_grads(model)
return param_grads, elapsed, mem
def _max_diff(ref, other):
"""Return max absolute difference between two tensors or collections."""
if ref is None or other is None:
return 0.0
if isinstance(ref, (list, tuple)):
diffs = [_max_diff(r, o) for r, o in zip(ref, other)]
return max(diffs) if diffs else 0.0
return torch.max(torch.abs(ref.detach() - other.detach())).item()
def _collect_param_grads(model):
grads = {}
for name, param in model.named_parameters():
if param.grad is None:
continue
key = _param_key(name)
if key is not None:
grads[key] = param.grad.detach().clone()
return grads
def _param_key(name):
return name.split(".")[-1]
@pytest.mark.parametrize("size", config.keys())
@pytest.mark.parametrize("seq_size", seq_sizes)
def test_selective_activation_checkpoint(size, seq_size):
ln_model, sln_model = config[size].build()
data = torch.randn((seq_size, config[size]._hidden_size), device=device)
_warmup(ln_model, data)
ln_fwd_out, ln_fwd_time, ln_fwd_mem = _run_fwd(ln_model, data)
ln_grads, ln_bwd_time, ln_bwd_mem = _run_bwd(ln_model, ln_fwd_out)
_warmup(sln_model, data)
sln_fwd_out, sln_fwd_time, sln_fwd_mem = _run_fwd(sln_model, data)
sln_grads, sln_bwd_time, sln_bwd_mem = _run_bwd(sln_model, sln_fwd_out)
assert ln_fwd_mem > 6 * sln_fwd_mem, (
"selective activation checkpointing does not reduce forward memory by 6X, only by"
f" {ln_fwd_mem/sln_fwd_mem}!"
)
assert ln_bwd_time < sln_bwd_time, (
"selective activation activation checkpointing backward pass is NOT slower than native!"
f" got Native LayerNormMLP Backward Time: {ln_bwd_time} ms and Selective Activation"
f" Checkpointed LayerNormMLP Backward Time: {sln_bwd_time} ms"
)
diff = _max_diff(ln_fwd_out, sln_fwd_out)
assert diff == 0.0, f"outputs are not equal! maximum difference {diff}"
for key in [
"layer_norm_weight",
"layer_norm_bias",
"fc1_weight",
"fc1_bias",
"fc2_weight",
"fc2_bias",
]:
diff = _max_diff(ln_grads[key], sln_grads[key])
assert diff == 0.0, f"gradients for {key} are not equal! maximum difference: {diff}"
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -122,8 +122,15 @@ def check_nvfp4_gemm_versus_reference(
)
# Create reference quantized tensors needed by reference GEMM
# Reference GEMM is only rowwise.
if x_columnwise:
x_nvfp4_ref = ref_quantizer.quantize(x.t().contiguous())
else:
x_nvfp4_ref = ref_quantizer.quantize(x)
if w_columnwise:
w_nvfp4_ref = ref_quantizer.quantize(w.t().contiguous())
else:
w_nvfp4_ref = ref_quantizer.quantize(w)
# Reference GEMM using quantizer's qgemm method
y_ref = ref_quantizer.qgemm(
......@@ -155,6 +162,10 @@ def check_nvfp4_gemm_versus_reference(
use_grad = False
use_split_accumulator = False
if x_columnwise:
x_nvfp4_native.update_usage(rowwise_usage=False)
if w_columnwise:
w_nvfp4_native.update_usage(rowwise_usage=False)
# Native cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
......@@ -212,11 +223,11 @@ def check_nvfp4_gemm_versus_reference(
@pytest.mark.parametrize(
"is_x_columnwise, is_w_columnwise",
[
(False, False), # TN
(True, False), # NN
(True, True), # NT
],
ids=["rowxrow"],
ids=["rowxrow", "colxrow", "colxcol"],
)
def test_nvfp4_gemm_versus_reference(
M: int,
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# NOTE: This file depends on the success of test_nvfp4_quantize_exact.py
# and test_nvfp4_rht_quantize_exact.py. It is kept separate to make sure all
# the functionalities work as expected; otherwise the reference implementation
# would get messy.
# Due to the structure of NVFP4Quantizer, we need to test the RHT functionality
# together with the quantization functionality.
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling
import pytest
import torch
import random
import math
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
def generate_random_multiples_sum(total=8192, n=4, multiple=64):
if total % multiple != 0:
raise ValueError(f"Total ({total}) must be a multiple of {multiple}")
if (total // multiple) < n:
raise ValueError("Total too small for given n and multiple.")
# Work in units of multiples
total_units = total // multiple
# choose n-1 random cut points in [1, total_units)
cuts = sorted(random.sample(range(1, total_units), n - 1))
# convert to segment lengths
parts = (
[cuts[0]] + [cuts[i] - cuts[i - 1] for i in range(1, len(cuts))] + [total_units - cuts[-1]]
)
# convert back to multiples
return [p * multiple for p in parts]
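# Worked example (illustrative; actual cut points are random): with total=512,
# n=4, multiple=64 there are 8 units; cuts [2, 3, 6] give unit lengths
# [2, 1, 3, 2], i.e. [128, 64, 192, 128] -- multiples of 64 summing to 512.
#   parts = generate_random_multiples_sum(total=512, n=4, multiple=64)
#   assert sum(parts) == 512 and all(p % 64 == 0 for p in parts)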
def generate_split_sections(M: int, N: int, edge_cases: str) -> list[int]:
least_multiple = 64
num_chunks = 4
split_sections = None
avg_split = M // num_chunks
if M == 0 or N == 0:
# all zeros
return [0] * num_chunks
if edge_cases == "regular":
split_sections = [avg_split] * num_chunks
elif edge_cases == "zero_tokens_front":
split_sections = [0] + [avg_split] * (num_chunks - 2) + [avg_split * 2]
elif edge_cases == "zero_tokens_end":
split_sections = [avg_split * 2] + [avg_split] * (num_chunks - 2) + [0]
elif edge_cases == "zero_tokens_middle":
split_sections = [avg_split] * (num_chunks - 2) + [0] + [avg_split * 2]
elif edge_cases == "random_uneven_split":
split_sections = generate_random_multiples_sum(M, num_chunks, least_multiple)
else:
raise ValueError(f"Invalid edge case: {edge_cases}")
# adds up the split_sections to make it M
assert sum(split_sections) == M, "The split_sections do not add up to M"
# make sure every split_section is a multiple of least_multiple
for split_section in split_sections:
assert (
split_section % least_multiple == 0
), "The split_sections are not multiples of least_multiple"
return split_sections
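# Illustrative outputs for M=256, num_chunks=4 (avg_split=64):
#   "regular"            -> [64, 64, 64, 64]
#   "zero_tokens_front"  -> [0, 64, 64, 128]
#   "zero_tokens_middle" -> [64, 64, 0, 128]
#   "zero_tokens_end"    -> [128, 64, 64, 0]
# "random_uneven_split" delegates to generate_random_multiples_sum above.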
# Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization without padding
def get_nvfp4_scale_shape_no_padding(shape, columnwise):
M, K = 1, 1
M = math.prod(shape[:-1])
K = shape[-1]
if columnwise:
outer = K
inner = math.ceil(M / 16)
return (outer, inner)
# rowwise
outer = M
inner = math.ceil(K / 16)
return (outer, inner)
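# Worked example: NVFP4 1D block scaling stores one scale per 16 elements along
# the quantized dimension, so a (128, 48) input has rowwise scales of shape
# (128, ceil(48/16)) = (128, 3) and columnwise (transpose) scales of shape
# (48, ceil(128/16)) = (48, 8).
#   assert get_nvfp4_scale_shape_no_padding((128, 48), columnwise=False) == (128, 3)
#   assert get_nvfp4_scale_shape_no_padding((128, 48), columnwise=True) == (48, 8)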
def reference_group_quantize(
x: torch.Tensor,
quantizers: list[NVFP4Quantizer],
split_sections: list[int],
return_identity: bool,
return_transpose: bool,
) -> torch.Tensor:
x_view = x.reshape(-1, x.size(-1))
x_chunks = torch.split(x, split_sections)
# rowwise quantization
x_qx = []
x_sx = []
x_amax_rowwise = []
# columnwise quantization
x_qx_t = []
x_sx_t = []
x_amax_colwise = []
for i in range(len(x_chunks)):
x_chunk = x_chunks[i]
x_nvfp4_res = quantizers[i](x_chunk)
if return_identity:
x_qx.append(x_nvfp4_res._rowwise_data.view(dtype=torch.uint8))
x_sx.append(x_nvfp4_res._rowwise_scale_inv)
x_amax_rowwise.append(x_nvfp4_res._amax_rowwise)
else:
x_qx.append(None)
x_sx.append(None)
x_amax_rowwise.append(None)
if return_transpose:
x_qx_t.append(x_nvfp4_res._columnwise_data.view(dtype=torch.uint8))
x_sx_t.append(x_nvfp4_res._columnwise_scale_inv)
x_amax_colwise.append(x_nvfp4_res._amax_columnwise)
else:
x_qx_t.append(None)
x_sx_t.append(None)
x_amax_colwise.append(None)
return x_qx, x_sx, x_amax_rowwise, x_qx_t, x_sx_t, x_amax_colwise
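# Shape sketch (assuming NVFP4 packs two FP4 values per uint8 byte): a
# (64, 64) chunk yields rowwise data of shape (64, 32) when viewed as uint8,
# while the scale-inv buffers may be padded; the checks below therefore slice
# out the valid region via get_nvfp4_scale_shape_no_padding before comparing.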
def assert_same_shape_and_dtype(x: torch.Tensor, y: torch.Tensor) -> None:
assert x.shape == y.shape
assert x.dtype == y.dtype
def check_group_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
return_identity: bool,
return_transpose: bool,
split_sections: list[int],
with_rht: bool = True,
with_post_rht_amax: bool = True,
with_random_sign_mask: bool = True,
) -> None:
te_dtype = tex.DType.kFloat4E2M1
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = torch.randn((M, N), dtype=x_dtype, device=device)
num_chunks = len(split_sections)
x_splits = torch.split(x, split_sections)
# Quantize
quantizers = [
NVFP4Quantizer(
fp4_dtype=te_dtype,
rowwise=return_identity,
columnwise=return_transpose,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=with_rht,
with_post_rht_amax=with_post_rht_amax,
with_random_sign_mask=with_random_sign_mask,
)
for _ in range(len(split_sections))
]
x_qx_ref, x_sx_ref, x_amax_rowwise_ref, x_qx_t_ref, x_sx_t_ref, x_amax_colwise_ref = (
reference_group_quantize(x, quantizers, split_sections, return_identity, return_transpose)
)
split_quantize_outputs = tex.split_quantize(x, split_sections, quantizers)
if return_identity:
x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs]
x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs]
x_amax_rowwise = [output._amax_rowwise for output in split_quantize_outputs]
for i in range(len(x_qx)):
if split_sections[i] == 0:
# then just assert the same shape and dtype because the buffer won't be zeroed out
assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i])
assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i])
assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i])
else:
torch.testing.assert_close(
x_amax_rowwise[i], x_amax_rowwise_ref[i], atol=0.0, rtol=0.0
)
torch.testing.assert_close(x_qx[i], x_qx_ref[i], atol=0.0, rtol=0.0)
valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, False)
x_sx_valid = x_sx[i][: valid_scale_shape[0], : valid_scale_shape[1]]
x_sx_ref_valid = x_sx_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]]
torch.testing.assert_close(x_sx_valid, x_sx_ref_valid, atol=0.0, rtol=0.0)
if return_transpose:
x_qx_t = [
output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs
]
x_sx_t = [output._columnwise_scale_inv for output in split_quantize_outputs]
x_amax_colwise = [output._amax_columnwise for output in split_quantize_outputs]
# assert with zero tolerance
for i in range(len(x_qx_t)):
if split_sections[i] == 0:
# then just assert the same shape and dtype because the buffer won't be zeroed out
assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i])
assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i])
assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i])
else:
torch.testing.assert_close(
x_amax_colwise[i], x_amax_colwise_ref[i], atol=0.0, rtol=0.0
)
torch.testing.assert_close(x_qx_t[i], x_qx_t_ref[i], atol=0.0, rtol=0.0)
valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, True)
x_sx_t_valid = x_sx_t[i][: valid_scale_shape[0], : valid_scale_shape[1]]
x_sx_t_ref_valid = x_sx_t_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]]
torch.testing.assert_close(x_sx_t_valid, x_sx_t_ref_valid, atol=0.0, rtol=0.0)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# edge case, zero tokens for all
(0, 512),
# full tile cases
(256, 1024),
(1024, 256),
# larger sizes
(8192, 1024),
(16384, 8192),
(16384, 16384),
],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"edge_cases",
[
"regular",
"zero_tokens_front",
"zero_tokens_end",
"zero_tokens_middle",
"random_uneven_split",
],
)
@pytest.mark.parametrize(
"quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"]
)
@pytest.mark.parametrize(
"with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
)
@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"])
def test_rht_with_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
edge_cases: str,
quantize_mode: str,
with_random_sign_mask: bool,
with_rht: bool,
) -> None:
split_sections = generate_split_sections(M, N, edge_cases)
# pre-RHT amax is currently disabled, so post-RHT amax is enabled exactly when RHT is
with_post_rht_amax = with_rht
if quantize_mode == "quantize":
return_identity = True
return_transpose = False
elif quantize_mode == "quantize_transpose":
return_identity = True
return_transpose = True
elif quantize_mode == "quantize_colwise_only":
return_identity = False
return_transpose = True
else:
raise ValueError(f"Invalid quantize mode: {quantize_mode}")
check_group_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
M=M,
N=N,
return_identity=return_identity,
return_transpose=return_transpose,
split_sections=split_sections,
with_rht=with_rht,
with_post_rht_amax=with_post_rht_amax,
with_random_sign_mask=with_random_sign_mask,
)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import List, Tuple
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)
......@@ -151,6 +156,74 @@ def quantize_fp4(
return qx, sx, qx_t, sx_t
def group_quantize_fp4(
x: torch.Tensor,
use_stochastic_rounding: bool,
use_2D: bool,
use_RHT: bool,
split_sections: list[int],
use_tex_split_quantize: bool = True,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
"""
Group quantize function with toggle between tex.split_quantize and manual split/call methods.
Args:
x (torch.Tensor): Input tensor.
use_stochastic_rounding (bool): Use stochastic rounding.
use_2D (bool): Use 2D quantization.
use_RHT (bool): Use RHT.
split_sections (list[int]): Split sizes for inputs.
use_tex_split_quantize (bool): Toggle method. If True, use tex.split_quantize, else use manual split and per-quantizer invocation.
Returns:
tuple: Lists of quantized tensors and scale tensors for all sections.
"""
num_tensors = len(split_sections)
nvfp4_quantizers = [
NVFP4Quantizer(
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=use_RHT,
with_post_rht_amax=True,
stochastic_rounding=use_stochastic_rounding,
with_2d_quantization=use_2D,
)
for _ in range(num_tensors)
]
if use_tex_split_quantize:
outputs = tex.split_quantize(x, split_sections, nvfp4_quantizers)
qx_list = [output._rowwise_data.view(dtype=torch.uint8) for output in outputs]
sx_list = [output._rowwise_scale_inv for output in outputs]
qx_t_list = [output._columnwise_data.view(dtype=torch.uint8) for output in outputs]
sx_t_list = [output._columnwise_scale_inv for output in outputs]
else:
x_chunks = torch.split(x, split_sections)
qx_list = []
sx_list = []
qx_t_list = []
sx_t_list = []
for i in range(num_tensors):
x_chunk = x_chunks[i]
x_nvfp4_sut = nvfp4_quantizers[i](x_chunk)
assert x_nvfp4_sut._rowwise_data is not None
qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._rowwise_scale_inv is not None
sx = x_nvfp4_sut._rowwise_scale_inv
assert x_nvfp4_sut._columnwise_data is not None
qx_t = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
assert x_nvfp4_sut._columnwise_scale_inv is not None
sx_t = x_nvfp4_sut._columnwise_scale_inv
qx_list.append(qx)
sx_list.append(sx)
qx_t_list.append(qx_t)
sx_t_list.append(sx_t)
return qx_list, sx_list, qx_t_list, sx_t_list
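# Illustrative call (a sketch; the tensor and splits are hypothetical): the
# same input can be quantized through the fused tex.split_quantize path or
# through per-chunk quantizer calls, and the tests below parametrize over both.
#   x = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
#   qx, sx, qx_t, sx_t = group_quantize_fp4(
#       x, use_stochastic_rounding=False, use_2D=False, use_RHT=True,
#       split_sections=[64, 192], use_tex_split_quantize=False,
#   )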
def check_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool
) -> None:
......@@ -209,6 +282,92 @@ def check_quantization_nvfp4_versus_reference(
assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than round-to-nearest."
def check_group_quantization_nvfp4_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
use_2D: bool,
use_RHT: bool,
num_splits: int,
use_tex_split_quantize: bool = True,
) -> None:
device = "cuda"
torch.manual_seed(seed)
n_iters = 50
split_sections = [M // num_splits] * num_splits
x_total = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1
x_splits = torch.split(x_total, split_sections)
q_rn_list, s_rn_list, q_t_rn_list, s_t_rn_list = group_quantize_fp4(
x_total,
use_stochastic_rounding=False,
use_2D=use_2D,
use_RHT=use_RHT,
split_sections=split_sections,
use_tex_split_quantize=use_tex_split_quantize,
)
sr_n_iter_results = []
for i in range(n_iters):
q_sr_list, s_sr_list, q_t_sr_list, s_t_sr_list = group_quantize_fp4(
x_total,
use_stochastic_rounding=True,
use_2D=use_2D,
use_RHT=use_RHT,
split_sections=split_sections,
use_tex_split_quantize=use_tex_split_quantize,
)
sr_n_iter_results.append((q_sr_list, s_sr_list, q_t_sr_list, s_t_sr_list))
for i, x in enumerate(x_splits):
y = x.t().contiguous()
if use_RHT:
y = RHT(y)
amax = torch.max(torch.abs(x)).float()
# fetch q_rn, s_rn, q_t_rn, s_t_rn
q_rn = q_rn_list[i]
s_rn = s_rn_list[i]
q_t_rn = q_t_rn_list[i]
s_t_rn = s_t_rn_list[i]
dq_rn = dequantize_fp4(q_rn, s_rn, amax)
dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax)
error_rn = (dq_rn - x).float()
me_rn = torch.sqrt((error_rn * error_rn).mean())
error_t_rn = (dq_t_rn - y).float()
me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean())
sr_result = torch.zeros_like(x).float()
sr_t_result = torch.zeros_like(x).float().t().contiguous()
for iter_idx in range(n_iters):
result_sr = sr_n_iter_results[iter_idx]
q_sr = result_sr[0][i]
s_sr = result_sr[1][i]
q_t_sr = result_sr[2][i]
s_t_sr = result_sr[3][i]
dq_sr = dequantize_fp4(q_sr, s_sr, amax)
dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax)
sr_result += dq_sr.float()
sr_t_result += dq_t_sr.float()
# Get the mean result of the stochastic rounding
# It should be more accurate than the RN result
sr_result /= n_iters
error_sr = (sr_result - x).float()
me_sr = torch.sqrt((error_sr * error_sr).mean())
sr_t_result /= n_iters
error_t_sr = (sr_t_result - y).float()
me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
assert me_sr < me_rn, "Stochastic rounding failed - error larger than round-to-nearest."
assert (
me_t_sr < me_t_rn
), "Stochastic rounding failed - error larger than round-to-nearest."
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
......@@ -236,3 +395,39 @@ def test_quantization_block_tiling_versus_reference(
M=M,
N=N,
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
(8192, 8192),
(4096, 7168),
(16384, 2048),
],
)
@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("use_2D", [False], ids=str)
@pytest.mark.parametrize("use_RHT", [True], ids=str)
@pytest.mark.parametrize("num_splits", [4, 8], ids=str)
@pytest.mark.parametrize("use_tex_split_quantize", [True, False], ids=str)
def test_group_stochastic_rounding_quantization_versus_reference(
x_dtype: torch.dtype,
use_2D: bool,
use_RHT: bool,
num_splits: int,
use_tex_split_quantize: bool,
M: int,
N: int,
) -> None:
if x_dtype == torch.float32 and use_RHT:
pytest.skip("RHT is only supported with bfloat16")
check_group_quantization_nvfp4_versus_reference(
x_dtype=x_dtype,
use_2D=use_2D,
use_RHT=use_RHT,
M=M,
N=N,
num_splits=num_splits,
use_tex_split_quantize=use_tex_split_quantize,
)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -54,9 +54,13 @@ gc.disable()
class Utils:
# Tensor used for simulating long-running GPU work in long_job()
tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16)
# Test tensor dimensions: _B x _S x _D = 128 x 512 x 256 = 16,777,216 elements
# This exceeds the 256K element threshold for offloading (cpu_offload.py line 443).
# For quantized tensors, scale_inv tensors (~524K elements for block scaling) also exceed threshold.
_B = 128
_S = 512
_H = 4
_D = 256
......@@ -395,6 +399,9 @@ class TestsDefaultOffloadSynchronizer:
offload_synchronizer.push_tensor(x1)
offload_synchronizer.push_tensor(x1)
offload_synchronizer.push_tensor(x1)
# Verify x1 is not corrupted after pushing (important for QuantizedTensor)
if recipe is not None:
x1.dequantize() # Should not raise - tensor should still be valid
offload_synchronizer.fwd_step()
# Only one copy of the tensor is allocated on the CPU.
assert Utils.get_cpu_memory_mb() == pytest.approx(init_cpu_memory + 1 * x_size, 0.1)
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -203,7 +203,8 @@ _test_cuda_graphs_modules: List[str] = [
# creating TMA descriptor for MXFP8 quantization.
"linear",
"transformer",
"layernorm_mlp",
"layernorm_mlp_nocheckpoint",
"layernorm_mlp_checkpoint",
"layernorm_linear",
"mha",
"linear_op",
......@@ -245,12 +246,23 @@ def _test_cuda_graphs(
)
for _ in range(num_layers)
]
elif module == "layernorm_mlp":
elif module == "layernorm_mlp_nocheckpoint":
modules = [
LayerNormMLP(
model_config.hidden_size,
model_config.hidden_size,
params_dtype=dtype,
checkpoint=False,
)
for _ in range(num_layers)
]
elif module == "layernorm_mlp_checkpoint":
modules = [
LayerNormMLP(
model_config.hidden_size,
model_config.hidden_size,
params_dtype=dtype,
checkpoint=True,
)
for _ in range(num_layers)
]
......@@ -389,6 +401,17 @@ def test_make_graphed_callables(
)
if fp8_params:
pytest.skip("NVFP4 params not supported")
if (
fp8
and fp8_recipe.delayed()
and torch.cuda.get_device_capability() >= (10, 0)
and module == "layernorm_mlp_checkpoint"
):
pytest.skip(
"CUDA graphs not supported for LayerNormMLP "
"with checkpoint=True, SM>=10, "
"and DelayedScaling recipe"
)
if fp8 and not fp8_available:
pytest.skip("FP8 not supported on rocm GPU.")
if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
......@@ -421,7 +444,8 @@ def test_make_graphed_callables(
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
"transformer",
"layernorm_mlp",
"layernorm_mlp_nocheckpoint",
"layernorm_mlp_checkpoint",
"layernorm_linear",
"linear",
"mha",
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -28,7 +28,6 @@ dtype = torch.bfloat16
class TestDeferredInit:
@staticmethod
def get_module_args(module):
hidden_size = num_heads * head_dim
......@@ -82,3 +81,45 @@ class TestDeferredInit:
"on CUDA device"
)
del module
@pytest.mark.parametrize("module_type", _core_modules)
def test_reset_parameters_doesnt_change_parameter_stats(
self,
module_type: torch.nn.Module,
) -> None:
"""Test for github issue #2528 and #2529 to ensure that reset_parameters() doesn't change
the parameter mean and std"""
args, kwargs = TestDeferredInit.get_module_args(module_type)
kwargs["device"] = "cuda"
module = module_type(*args, **kwargs)
param_stats = {
name: {"mean": param.mean(), "std": param.std()}
for name, param in module.named_parameters()
}
with torch.no_grad():
module.reset_parameters()
param_stats_after = {
name: {"mean": param.mean(), "std": param.std()}
for name, param in module.named_parameters()
}
for name, stats in param_stats_after.items():
torch.testing.assert_close(
stats["mean"],
param_stats[name]["mean"],
atol=1e-3,
rtol=1e-3,
msg=f"{name} mean changed after reset_parameters",
)
torch.testing.assert_close(
stats["std"],
param_stats[name]["std"],
atol=1e-3,
rtol=1e-3,
msg=f"{name} std changed after reset_parameters",
)
del module