# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from collections import OrderedDict
import math
import os
from typing import Dict, List, Tuple, Optional

import pytest
import copy
import random

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.utils.cpp_extension import IS_HIP_EXTENSION

from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
    fp8_autocast,
    fp8_model_init,
)
from transformer_engine.pytorch.utils import (
    init_method_normal,
    scaled_init_method_normal,
    attention_mask_func,
    is_bf16_compatible,
)
from transformer_engine.pytorch import (
    DotProductAttention,
    LayerNormLinear,
    LayerNormMLP,
    Linear,
    GroupedLinear,
    BatchedLinear,
    MultiheadAttention,
    RMSNorm,
    TransformerLayer,
    LayerNorm,
    Fp8Padding,
    Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm, batchgemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex


# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

sm_80plus = get_device_compute_capability() >= (8, 0)

seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

if torch_version() >= (2, 7, 0):
    torch._dynamo.config.recompile_limit = 16
else:
    torch._dynamo.config.cache_size_limit = 16


class ModelConfig:
    def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
        self.hidden_size = hidden_size
        self.eps = eps
        self.num_attention_heads = num_attention_heads
        self.embed = embed
        self.num_layers = num_layers
        self.seq_len = seq_len


model_configs = {
    "small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}

model_configs_inference = {
    # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]

param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)

batch_sizes = [1, 2]

all_boolean = [True, False]

all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]

mask_types = ["causal", "no_mask"]

fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),
    recipe.Float8CurrentScaling(),
]


def get_causal_attn_mask(sq: int) -> torch.Tensor:
    return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()


def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.
""" if dtype == torch.float32: return dict(rtol=1.3e-6, atol=1e-5) if dtype == torch.float16: return dict(rtol=1e-3, atol=1e-5) if dtype == torch.bfloat16: return dict(rtol=1.6e-2, atol=1e-5) raise ValueError(f"Unsuppored dtype ({dtype})") def assert_allclose( l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None ) -> bool: """Ensures two lists are equal.""" assert len(l1) == len(l2), "Unequal number of outputs." for i, (t1, t2) in enumerate(zip(l1, l2)): tols = dict(atol=atol) if rtol is not None: tols["rtol"] = rtol result = torch.allclose(t1, t2, **tols) if not result: diff = torch.abs(t1 - t2) tol = atol + (rtol * torch.abs(t2)) exceed_mask = diff > tol if exceed_mask.any(): indices = torch.nonzero(exceed_mask, as_tuple=True) max_diff = diff[exceed_mask].max() max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0] max_location = [idx[max_idx].item() for idx in indices] msg = ( f"Outputs not close enough in tensor at idx={i}. " f"Maximum difference at location {max_location} " f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} " f"(diff {max_diff.item()})." ) raise AssertionError(msg) def reset_rng_states() -> None: """revert back to initial RNG state.""" torch.set_rng_state(_cpu_rng_state) torch.cuda.set_rng_state(_cuda_rng_state) @pytest.fixture(autouse=True) def reset_global_fp8_state(): yield FP8GlobalStateManager.reset() def _test_batched_linear_accuracy( block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation ): reset_rng_states() if fp8: FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( (config.seq_len, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) inp_hidden_states.retain_grad() assert config.seq_len % num_gemms == 0 m_splits = torch.tensor([config.seq_len // num_gemms for i in range(num_gemms)]) assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, BatchedLinear): m_splits = m_splits * bs out = block(inp_hidden_states, m_splits.tolist()) else: out = torch.cat( [ block[i](inp) for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist())) ] ) loss = out.sum() loss.backward() torch.cuda.synchronize() outputs = [out, inp_hidden_states.grad] return outputs @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("num_gemms", [4, 8]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) def test_batched_linear_accuracy( dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, fuse_wgrad_accumulation, parallel_mode=None, ): batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2")) if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) if recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches pytest.skip("MXFP8 unsupported for batched linear.") if fp8 and recipe.float8_current_scaling(): pytest.skip("Float8 Current Scaling unsupported for batched linear.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): batched_linear = BatchedLinear( num_gemms, 
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            parallel_mode=parallel_mode,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
        ).eval()
        sequential_linear = torch.nn.ModuleList(
            [
                Linear(
                    config.hidden_size,
                    4 * config.hidden_size,
                    bias=False,
                    params_dtype=dtype,
                    parallel_mode=parallel_mode,
                    device="cuda",
                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                ).eval()
                for _ in range(num_gemms)
            ]
        )

    # Share params: each batched weight{i} stacks `batch_num` per-GEMM weights
    # row-wise, so copy each slice into the corresponding sequential Linear.
    with torch.no_grad():
        for i in range(num_gemms // batch_num):
            weight = getattr(batched_linear, f"weight{i}").clone()
            # bias = getattr(batched_linear, f"bias{i}").clone()
            if fuse_wgrad_accumulation:
                weight_i = getattr(batched_linear, f"weight{i}")
                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
            for j in range(batch_num):
                sequential_linear[i * batch_num + j].weight = Parameter(
                    weight[
                        weight.shape[0] // batch_num * j : weight.shape[0] // batch_num * (j + 1)
                    ].clone()
                )
                # sequential_linear[i * batch_num + j].bias = Parameter(
                #     bias[bias.shape[0] // batch_num * j : bias.shape[0] // batch_num * (j + 1)].clone()
                # )
                if fuse_wgrad_accumulation:
                    sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[
                        weight_i.main_grad.shape[0]
                        // batch_num
                        * j : weight_i.main_grad.shape[0]
                        // batch_num
                        * (j + 1)
                    ].clone()

    outputs_ref = _test_batched_linear_accuracy(
        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
    )
    outputs = _test_batched_linear_accuracy(
        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
    )

    # Outputs and input gradients from the batched and sequential paths should match closely.
    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)


if __name__ == "__main__":
    test_batched_linear_accuracy(
        torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True
    )
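
# A minimal, plain-torch sketch of the equivalence this test exercises: splitting
# an input along the token dimension and applying a separate weight to each chunk
# (here via torch.bmm) should match running the chunks through per-group matmuls
# sequentially and concatenating the results. This is illustrative only; the exact
# weight layout of BatchedLinear is an assumption, and the function is never
# invoked by the test suite.
def _grouped_vs_sequential_sketch(
    num_groups: int = 4, tokens_per_group: int = 8, in_features: int = 16, out_features: int = 32
) -> None:
    inp = torch.randn(num_groups * tokens_per_group, in_features)
    weights = [torch.randn(out_features, in_features) for _ in range(num_groups)]

    # Sequential reference: each token chunk goes through its own weight.
    chunks = torch.split(inp, tokens_per_group, dim=0)
    sequential_out = torch.cat([x @ w.t() for x, w in zip(chunks, weights)], dim=0)

    # Batched view: one bmm over (groups, tokens, in) x (groups, in, out).
    batched_out = torch.bmm(
        inp.view(num_groups, tokens_per_group, in_features),
        torch.stack(weights).transpose(1, 2),
    ).reshape(num_groups * tokens_per_group, out_features)

    torch.testing.assert_close(batched_out, sequential_out)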