# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. from __future__ import annotations import logging import os from contextlib import contextmanager from typing import Optional, Tuple, Dict, Any, List from packaging.version import Version as PkgVersion import torch import transformer_engine import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import InferenceParams, QuantizedTensor from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends from transformer_engine.pytorch.attention.dot_product_attention.utils import ( get_attention_backend, AttentionParams, AttentionLogging, check_set_window_size, ) from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: """Convert type name to PyTorch dtype""" if isinstance(dtype, torch.dtype): return dtype name = str(dtype).strip().lower() if name.startswith("torch."): name = name.replace("torch.", "", 1) if name.startswith("fp"): name = name.replace("fp", "float", 1) dtype = dict( float32=torch.float32, float=torch.float32, float64=torch.float64, double=torch.float64, float16=torch.float16, half=torch.float16, bfloat16=torch.bfloat16, bf16=torch.bfloat16, float8_e4m3fn=torch.float8_e4m3fn, float8_e4m3=torch.float8_e4m3fn, float8e4m3=torch.float8_e4m3fn, float8=torch.float8_e4m3fn, float8_e5m2=torch.float8_e5m2, float8e5m2=torch.float8_e5m2, uint8=torch.uint8, byte=torch.uint8, int8=torch.int8, char=torch.int8, int16=torch.int16, short=torch.int16, int32=torch.int32, int=torch.int32, int64=torch.int64, long=torch.int64, bool=torch.bool, )[name] return dtype def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: """Estimated numerical error for a datatype Based on tolerances for torch.testing.assert_close. """ # Transformer Engine dtypes if isinstance(dtype, tex.DType): if dtype == tex.DType.kFloat4E2M1: return dict(rtol=0.25, atol=0.125) # epsilon = 0.25 dtype = { tex.DType.kByte: torch.uint8, tex.DType.kInt32: torch.int32, tex.DType.kFloat32: torch.float32, tex.DType.kFloat16: torch.half, tex.DType.kBFloat16: torch.bfloat16, tex.DType.kFloat8E4M3: torch.float8_e4m3fn, tex.DType.kFloat8E5M2: torch.float8_e5m2, }[dtype] # PyTorch dtypes if dtype == torch.float16: return dict(rtol=1e-3, atol=1e-5) if dtype == torch.bfloat16: return dict(rtol=1.6e-2, atol=1e-5) if dtype == torch.float32: return dict(rtol=1.3e-6, atol=1e-5) if dtype == torch.float64: return dict(rtol=1e-7, atol=1e-7) if dtype == torch.float8_e4m3fn: return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 if dtype == torch.float8_e5m2: return dict(rtol=0.25, atol=0.125) # epsilon = 0.125 raise ValueError(f"Unsupported dtype ({dtype})") def quantization_tols(name: str) -> dict[str, float]: """Estimated numerical error for a quantization scheme""" if name in ( "fp8", "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) if name == "nvfp4": return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") def make_recipe(name: Optional[str]) -> Optional[Recipe]: """Make recipe for quantization scheme""" if name is None: return None if name in ("fp8", "fp8_delayed_scaling"): return transformer_engine.common.recipe.DelayedScaling( fp8_format=transformer_engine.common.recipe.Format.E4M3, amax_history_len=8, ) if name == "fp8_current_scaling": return transformer_engine.common.recipe.Float8CurrentScaling( fp8_format=transformer_engine.common.recipe.Format.E4M3, ) if name == "mxfp8": return transformer_engine.common.recipe.MXFP8BlockScaling( fp8_format=transformer_engine.common.recipe.Format.E4M3, ) if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling() if name == "nvfp4": return transformer_engine.common.recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, disable_2d_quantization=True, ) raise ValueError(f"Unsupported quantization scheme ({name})") # Cached RNG state _rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None def reset_rng_states() -> None: """Revert to deterministic RNG state""" global _rng_states if _rng_states is None: torch.manual_seed(1234) torch.cuda.manual_seed(1234) _rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state()) else: cpu_rng_state, cuda_rng_state = _rng_states torch.set_rng_state(cpu_rng_state) torch.cuda.set_rng_state(cuda_rng_state) def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): if not is_fp8: torch.testing.assert_close(a, b, atol=atol, rtol=rtol) return try: if a.dtype != b.dtype: a = a.to(b.dtype) torch.testing.assert_close(a, b, atol=atol, rtol=rtol) except Exception as e: logging.debug(e) rmse = torch.sqrt((a - b).square().mean()).item() logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) assert rmse < rmse_tol * rmse_range, ( name_a + " vs " + name_b + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( rmse, rmse_tol * rmse_range, rmse_tol, rmse_range ) ) class ModelConfig: def __init__( self, batch_size: int, max_seqlen_q: int, num_heads: int, head_dim_qk: int, max_seqlen_kv: int = None, num_gqa_groups: int = None, head_dim_v: int = None, softmax_type: str = "vanilla", dropout_p: float = 0.0, attn_mask_type: str = "no_mask", attn_bias_type: str = "no_bias", alibi_type: str = "none", bias_shape: str = "1hss", window_size: Tuple[int, int] = (-1, -1), context_parallel: bool = False, cp_comm_type: str = "p2p", return_max_logit=False, total_requests: int = None, max_ctx_len: int = None, num_layers: int = 1, eps: float = 1e-5, num_splits=1, ): self.batch_size = batch_size self.max_seqlen_q = max_seqlen_q self.max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv self.num_heads = num_heads self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups self.head_dim_qk = head_dim_qk self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v if self.head_dim_qk == self.head_dim_v: self.kv_channels = self.head_dim_qk else: self.kv_channels = (self.head_dim_qk, self.head_dim_v) self.hidden_size = self.num_heads * self.head_dim_qk self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v self.softmax_type = softmax_type self.dropout_p = dropout_p self.attn_mask_type = attn_mask_type self.attn_bias_type = attn_bias_type self.alibi_type = alibi_type self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross" self.bias_shape = bias_shape self.window_size = check_set_window_size(self.attn_mask_type, window_size) self.context_parallel = context_parallel self.cp_comm_type = cp_comm_type self.return_max_logit = return_max_logit self.total_requests = total_requests self.max_ctx_len = max_ctx_len self.num_layers = num_layers self.eps = eps self.num_splits = num_splits @contextmanager def logging_context(highest_level=logging.WARNING): previous_level = logging.root.manager.disable logging.disable(highest_level) try: yield finally: logging.disable(previous_level) def get_available_attention_backends( config: ModelConfig, qkv_dtype: torch.dtype, qkv_layout: str, pad_between_seqs: bool = False, deterministic: bool = False, fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, is_training: bool = True, inference_params: Optional[InferenceParams] = None, ) -> Tuple[List, List]: """Check for all available attention backends that support a model configuration""" os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True alibi_slopes_shape = None if config.attn_bias_type == "alibi" and config.alibi_type == "custom": if config.bias_shape == "1hss": alibi_slopes_shape = [config.num_heads] if config.bias_shape == "bhss": alibi_slopes_shape = [config.batch_size, config.num_heads] core_attention_bias_shape = ( config.bias_shape if config.attn_bias_type == "post_scale_bias" else None ) core_attention_bias_requires_grad = False # d=256 is supported by cuDNN 9.0+ for inference but not training if ( config.attn_bias_type == "post_scale_bias" and config.head_dim_qk <= 128 and config.head_dim_v <= 128 ): core_attention_bias_requires_grad = True fused_attn_backends = [] available_backends = None flash_attention_backend = None fused_attention_backend = None def test(): attention_params = AttentionParams( qkv_dtype=qkv_dtype, qkv_layout=qkv_layout, batch_size=config.batch_size, num_heads=config.num_heads, num_gqa_groups=config.num_gqa_groups, max_seqlen_q=config.max_seqlen_q, max_seqlen_kv=config.max_seqlen_kv, head_dim_qk=config.head_dim_qk, head_dim_v=config.head_dim_v, attn_mask_type=config.attn_mask_type, window_size=config.window_size, alibi_slopes_shape=alibi_slopes_shape, core_attention_bias_type=config.attn_bias_type, core_attention_bias_shape=core_attention_bias_shape, core_attention_bias_requires_grad=core_attention_bias_requires_grad, pad_between_seqs=pad_between_seqs, attention_dropout=config.dropout_p, context_parallel=config.context_parallel, cp_comm_type=config.cp_comm_type, deterministic=deterministic, fp8=fp8, fp8_meta=fp8_meta, is_training=is_training, inference_params=inference_params, softmax_type=config.softmax_type, return_max_logit=config.return_max_logit, # allow all backends to pass so they can be used for testing; # check for FA3 availability later num_splits=1, ) ( use_flash_attention, flash_attention_backend, use_fused_attention, fused_attention_backend, use_unfused_attention, available_backends, ) = get_attention_backend(attention_params) # Check if FA3 is an available backend when num_splits != 1 if available_backends[0]: if config.num_splits != 1 and not flash_attention_backend > PkgVersion("3.0.0b"): available_backends[0] = False # Set attention.py _attention_backends var using return value # from get_attention_backend() _attention_backends["use_flash_attention"] = use_flash_attention _attention_backends["use_fused_attention"] = use_fused_attention _attention_backends["flash_attention_backend"] = flash_attention_backend _attention_backends["fused_attention_backend"] = fused_attention_backend _attention_backends["use_unfused_attention"] = use_unfused_attention _attention_backends["backend_selection_requires_update"] = False return available_backends, flash_attention_backend, fused_attention_backend backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} if AttentionLogging._is_logging_setup is False: AttentionLogging.setup_logging() for i in range(3): os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) _attention_backends["backend_selection_requires_update"] = True available_backends, flash_attention_backend, fused_attention_backend = test() if fused_attention_backend == FusedAttnBackend[backends[i]]: fused_attn_backends.append(fused_attention_backend) return available_backends, flash_attention_backend, fused_attn_backends @torch.no_grad def assert_close( actual: Optional[torch.Tensor], expected: Optional[torch.Tensor], *, check_device: bool = False, check_dtype: bool = False, check_layout: bool = False, **kwargs, ) -> None: """Assert that two tensors are close. This function is a wrapper around torch.testing.assert_close. It changes the defaults for device and dtype checks (useful when the reference implementation is computed in high precision on CPU) and it can handle quantized tensors. """ if isinstance(actual, QuantizedTensor): actual = actual.dequantize() if isinstance(expected, QuantizedTensor): expected = expected.dequantize() torch.testing.assert_close( actual, expected, check_device=check_device, check_dtype=check_dtype, check_layout=check_layout, **kwargs, ) def assert_close_grads( actual: Optional[torch.Tensor], expected: Optional[torch.Tensor], **kwargs, ) -> None: """Assert that two tensors have close gradients.""" if actual is None and expected is None: return assert actual is not None assert expected is not None assert_close(actual.grad, expected.grad, **kwargs)