Commit 87e3e56e authored by yuguo's avatar yuguo
Browse files

Merge commit '734bcedd' of...

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
...@@ -1604,16 +1604,18 @@ def print_debug_tensor_stats(prefix, tensor, hist=False): ...@@ -1604,16 +1604,18 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
@contextmanager @contextmanager
def use_jax_gemm(enabled=False): def use_jax_gemm(enabled=False):
orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None) orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS", None)
try: try:
if enabled: if enabled:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false"
else:
os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=true"
yield yield
finally: finally:
if enabled: if enabled:
if orig_custom_calls_filter is None: if orig_custom_calls_filter is None:
os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") os.environ.pop("NVTE_JAX_CUSTOM_CALLS")
else: else:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter os.environ["NVTE_JAX_CUSTOM_CALLS"] = orig_custom_calls_filter
...@@ -13,7 +13,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel ...@@ -13,7 +13,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel
get_cu_seqlens_on_cp_rank, get_cu_seqlens_on_cp_rank,
) )
import transformer_engine_torch as tex import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.common.recipe import DelayedScaling from transformer_engine.common.recipe import DelayedScaling
......
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
import logging import logging
import math import math
import os import os
from typing import Any, Dict, List, Tuple, Union, Optional import sys
from contextlib import contextmanager import pathlib
from typing import Any, Dict, Tuple, Union
import pytest import pytest
import torch import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
...@@ -20,11 +20,8 @@ from transformer_engine.pytorch.attention.dot_product_attention import ( ...@@ -20,11 +20,8 @@ from transformer_engine.pytorch.attention.dot_product_attention import (
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.dot_product_attention.utils import ( from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils, FlashAttentionUtils,
get_attention_backend,
check_set_window_size, check_set_window_size,
AttentionParams,
) )
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import RotaryPositionEmbedding from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import transformer_engine.pytorch.cpp_extensions as ext import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import ( from transformer_engine.pytorch.cpp_extensions.fused_attn import (
...@@ -49,21 +46,21 @@ from transformer_engine.pytorch.tensor.quantized_tensor import ( ...@@ -49,21 +46,21 @@ from transformer_engine.pytorch.tensor.quantized_tensor import (
restore_from_saved, restore_from_saved,
) )
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
reset_rng_states,
ModelConfig,
dtype_tols,
get_available_attention_backends,
)
# Only run FP8 tests on H100 # Only run FP8 tests on H100
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
# Initialize RNG state
seed = 1234 seed = 1234
torch.manual_seed(seed) # Reset RNG states
torch.cuda.manual_seed(seed) reset_rng_states()
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
...@@ -72,170 +69,20 @@ def reset_global_fp8_state(): ...@@ -72,170 +69,20 @@ def reset_global_fp8_state():
fp8.FP8GlobalStateManager.reset() fp8.FP8GlobalStateManager.reset()
class ModelConfig:
def __init__(
self,
batch_size: int,
num_heads: int,
num_gqa_groups: int,
head_dim_qk: int,
max_seqlen_q: int,
max_seqlen_kv: int,
dropout_p: float,
attn_mask_type: str,
attn_bias_type: str,
head_dim_v: int = None,
alibi_type: str = "none",
num_layers: int = 1,
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
total_requests: int = None,
max_ctx_len: int = None,
):
self.batch_size = batch_size
self.num_heads = num_heads
self.num_gqa_groups = num_gqa_groups
self.head_dim_qk = head_dim_qk
self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
self.hidden_size = num_heads * head_dim_qk
self.hidden_size_kv = num_gqa_groups * self.head_dim_v
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_kv
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers
self.bias_shape = bias_shape
self.window_size = window_size
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
@contextmanager
def logging_context(highest_level=logging.WARNING):
previous_level = logging.root.manager.disable
logging.disable(highest_level)
try:
yield
finally:
logging.disable(previous_level)
def _get_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
qkv_layout: str,
window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
is_training: bool = True,
inference_params: Optional[InferenceParams] = None,
) -> Tuple[List, List]:
"""Check if what attention backends support a model configuration"""
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes_shape = [config.num_heads]
if config.bias_shape == "bhss":
alibi_slopes_shape = [config.batch_size, config.num_heads]
core_attention_bias_shape = (
config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
)
core_attention_bias_requires_grad = False
# d=256 is supported by cuDNN 9.0+ for inference but not training
if (
config.attn_bias_type == "post_scale_bias"
and config.head_dim_qk <= 128
and config.head_dim_v <= 128
):
core_attention_bias_requires_grad = True
fused_attn_backends = []
available_backends = None
flash_attention_backend = None
fused_attention_backend = None
def test():
attention_params = AttentionParams(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
)
(
use_flash_attention,
use_fused_attention,
flash_attention_backend,
fused_attention_backend,
use_unfused_attention,
available_backends,
) = get_attention_backend(attention_params)
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["flash_attention_backend"] = flash_attention_backend
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return available_backends, flash_attention_backend, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, flash_attention_backend, fused_attn_backends
model_configs_base = { model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), "base_1_0": ModelConfig(8, 128, 16, 64),
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "base_2_0": ModelConfig(2, 2048, 24, 128),
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
"base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"), "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
"base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"), "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
"base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"), "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
"base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"), "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
"base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"), "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
"base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"), "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
} }
...@@ -279,7 +126,7 @@ def test_dot_product_attention( ...@@ -279,7 +126,7 @@ def test_dot_product_attention(
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
is_training = True is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends( available_backends, _, fused_attn_backends = get_available_attention_backends(
config, config,
qkv_dtype=dtype, qkv_dtype=dtype,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
...@@ -290,7 +137,7 @@ def test_dot_product_attention( ...@@ -290,7 +137,7 @@ def test_dot_product_attention(
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported: if not fused_attn_supported:
is_training = False is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends( available_backends, _, fused_attn_backends = get_available_attention_backends(
config, config,
qkv_dtype=dtype, qkv_dtype=dtype,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
...@@ -411,62 +258,26 @@ def test_dpa_checkpoint(dtype, model_configs, model): ...@@ -411,62 +258,26 @@ def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing""" """Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
if IS_HIP_EXTENSION:
model_configs_mla = { model_configs_mla = {
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig( "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0
8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128 "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0
), # self , 0 "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0
"mla_1_1": ModelConfig( "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 "mla_2_1": ModelConfig(
), # cross, 0 1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
"mla_1_2": ModelConfig( ), # cross, 1
4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 "mla_2_2": ModelConfig(
), # cross, 0 1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
"mla_2_0": ModelConfig( ), # cross, 1
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64 "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
), # self , 1 "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_2_1": ModelConfig( "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64 }
), # cross, 1
"mla_2_2": ModelConfig(
1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128
), # cross, 1
}
else:
model_configs_mla = {
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig(
8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128
), # self , 0
"mla_1_1": ModelConfig(
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_1_2": ModelConfig(
4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_2_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
), # self , 1
"mla_2_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # cross, 1
"mla_2_2": ModelConfig(
1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128
), # cross, 1
"mla_3_0": ModelConfig(
8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64
), # inference
"mla_3_1": ModelConfig(
8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
"mla_3_2": ModelConfig(
8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
}
@pytest.mark.skipif(not IS_HIP_EXTENSION and get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla]) @pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys()) @pytest.mark.parametrize("model", model_configs_mla.keys())
...@@ -477,40 +288,46 @@ def test_dpa_mla(dtype, model_configs, model): ...@@ -477,40 +288,46 @@ def test_dpa_mla(dtype, model_configs, model):
model_configs_mask = { model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
"mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
"mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
"mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
"mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), "mask_2_1": ModelConfig(
"mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right"
"mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"mask_5_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
), ),
"mask_2_2": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
),
"mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
"mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
"mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
"mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
"mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
"mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
"mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
"mask_5_1": ModelConfig( "mask_5_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
), ),
"mask_5_2": ModelConfig( "mask_5_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
),
"mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"),
"mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"),
"mask_7_0": ModelConfig(
2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
),
"mask_7_1": ModelConfig(
2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
), ),
"mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"),
"mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"),
"mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
"mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
"mask_8_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding", "no_bias"),
"mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"),
"mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"),
"mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"),
"mask_10_0": ModelConfig( "mask_10_0": ModelConfig(
2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" 2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
), ),
"mask_10_1": ModelConfig( "mask_10_1": ModelConfig(
2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" 2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
), ),
} }
...@@ -526,44 +343,102 @@ def test_dpa_mask(dtype, model_configs, model): ...@@ -526,44 +343,102 @@ def test_dpa_mask(dtype, model_configs, model):
model_configs_bias = { model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
"bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"), "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
"bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"), "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
"bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"), "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
"bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), # skipped
"bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped "bias_1_5": ModelConfig(
"bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped 2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"
"bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped ), # skipped
"bias_2_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
), # skipped
"bias_2_1": ModelConfig(
2,
128,
16,
64,
max_seqlen_kv=256,
attn_mask_type="padding",
attn_bias_type="post_scale_bias",
), # skipped
"bias_2_2": ModelConfig( "bias_2_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias" 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
), # skipped ), # skipped
"bias_2_3": ModelConfig( "bias_2_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias" 2,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="padding",
attn_bias_type="post_scale_bias",
), # skipped
"bias_2_4": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"
), # skipped
"bias_2_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
), # skipped ), # skipped
"bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped "bias_3_0": ModelConfig(
"bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
"bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), ),
"bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"), "bias_3_1": ModelConfig(
"bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), 2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"bias_3_2": ModelConfig(
4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"bias_3_3": ModelConfig( "bias_3_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias" 2,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
), # skipped
"bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
"bias_3_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
), # skipped ), # skipped
"bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
"bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
"bias_4_0": ModelConfig( "bias_4_0": ModelConfig(
4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias" 4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
), # skipped ), # skipped
"bias_4_1": ModelConfig( "bias_4_1": ModelConfig(
2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias" 2,
128,
16,
64,
max_seqlen_kv=256,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
), # skipped ), # skipped
"bias_4_2": ModelConfig( "bias_4_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias" 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
), # skipped ), # skipped
"bias_4_3": ModelConfig( "bias_4_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias" 2,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
), # skipped
"bias_4_4": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
), # skipped
"bias_4_5": ModelConfig(
2,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="alibi",
), # skipped ), # skipped
"bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
"bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
} }
...@@ -578,33 +453,29 @@ def test_dpa_bias(dtype, model_configs, model): ...@@ -578,33 +453,29 @@ def test_dpa_bias(dtype, model_configs, model):
model_configs_bias_shapes = { model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p, # test: b, h, hg, d, sq, skv, p,
"bias_1_0": ModelConfig( "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
"bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
"bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
"bias_1_4": ModelConfig(
4, 4,
16, 2048,
16, 24,
64,
128,
128, 128,
0.0, attn_mask_type="causal",
# mask, bias, bias_shape, attn_bias_type="alibi",
"no_mask", bias_shape="1hss",
"post_scale_bias", alibi_type="custom",
bias_shape="11ss",
),
"bias_1_1": ModelConfig(
2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="1hss"
),
"bias_1_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="b1ss"
),
"bias_1_3": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="bhss"
),
"bias_1_4": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="1hss", alibi_type="custom"
), ),
"bias_1_5": ModelConfig( "bias_1_5": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="bhss", alibi_type="custom" 2,
2048,
24,
128,
attn_mask_type="causal",
attn_bias_type="alibi",
bias_shape="bhss",
alibi_type="custom",
), ),
} }
...@@ -620,34 +491,36 @@ def test_dpa_bias_shapes(dtype, model_configs, model): ...@@ -620,34 +491,36 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
model_configs_swa = { model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), "swa_1_1": ModelConfig(2, 2048, 16, 64),
"swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
"swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
"swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"), "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"),
"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
"swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
"swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), "swa_3_2": ModelConfig(
"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), 2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right"
"swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), ),
"swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"), "swa_3_3": ModelConfig(
"swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
"swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"swa_6_1": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
), ),
"swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
"swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"),
"swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
"swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
"swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
"swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
"swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
"swa_6_2": ModelConfig( "swa_6_2": ModelConfig(
2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" 2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right"
), ),
"swa_6_3": ModelConfig( "swa_6_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
), ),
} }
@pytest.mark.skipif((not IS_HIP_EXTENSION) and (not FlashAttentionUtils.v2_3_plus), reason="Flash-attn 2.3+ is required.") @pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa]) @pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys()) @pytest.mark.parametrize("model", model_configs_swa.keys())
...@@ -658,18 +531,36 @@ def test_dpa_sliding_window(dtype, model_configs, model): ...@@ -658,18 +531,36 @@ def test_dpa_sliding_window(dtype, model_configs, model):
model_configs_alibi_slopes = { model_configs_alibi_slopes = {
# test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
"alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"), "alibi_1_0": ModelConfig(
"alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"), 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
),
"alibi_1_1": ModelConfig(
1,
128,
16,
64,
max_seqlen_kv=256,
attn_mask_type="causal",
attn_bias_type="alibi",
alibi_type="vanilla",
),
"alibi_2_0": ModelConfig( "alibi_2_0": ModelConfig(
2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom" 2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"
), ),
"alibi_2_1": ModelConfig( "alibi_2_1": ModelConfig(
1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom" 1,
1024,
24,
128,
max_seqlen_kv=2048,
attn_mask_type="causal",
attn_bias_type="alibi",
alibi_type="custom",
), ),
} }
@pytest.mark.skipif((not IS_HIP_EXTENSION) and (not FlashAttentionUtils.v2_3_plus), reason="Flash-attn 2.3+ is required.") @pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes]) @pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys()) @pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
...@@ -694,16 +585,38 @@ qkv_layouts = [ ...@@ -694,16 +585,38 @@ qkv_layouts = [
model_configs_layout = { model_configs_layout = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), "layout_0_0": ModelConfig(2, 128, 16, 64),
"layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), "layout_0_1": ModelConfig(
"layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"), ),
"layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
"layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), "layout_0_3": ModelConfig(
"layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), 1,
"layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), 128,
"layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), 16,
"layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"), 64,
max_seqlen_kv=256,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
),
"layout_1_0": ModelConfig(2, 2048, 24, 128),
"layout_1_1": ModelConfig(
2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
"layout_1_3": ModelConfig(
1,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
),
"layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048),
"layout_2_1": ModelConfig(
2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
} }
...@@ -720,55 +633,54 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): ...@@ -720,55 +633,54 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = { model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
"layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
"layout_0_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
"layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), "layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
"layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), "layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
"layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), "layout_1_2": ModelConfig(
"layout_2_0": ModelConfig( 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
), ),
"layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
"layout_2_1": ModelConfig( "layout_2_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
), ),
"layout_2_2": ModelConfig( "layout_2_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
),
"layout_3_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)
), ),
"layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)),
"layout_3_1": ModelConfig( "layout_3_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4)
), ),
"layout_3_2": ModelConfig( "layout_3_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4) 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4)
),
"layout_4_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
), ),
"layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)),
"layout_4_1": ModelConfig( "layout_4_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0)
), ),
"layout_4_2": ModelConfig( "layout_4_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0) 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0)
), ),
"layout_5_0": ModelConfig( "layout_5_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) 2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)
), ),
"layout_5_1": ModelConfig( "layout_5_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) 2,
2048,
24,
128,
num_gqa_groups=1,
attn_mask_type="padding_causal_bottom_right",
window_size=(4, 0),
), ),
"layout_5_2": ModelConfig( "layout_5_2": ModelConfig(
2, 2,
24, 2048,
24, 24,
128, 128,
2048, max_seqlen_kv=4096,
4096, attn_mask_type="padding_causal_bottom_right",
0.0,
"padding_causal_bottom_right",
"no_bias",
window_size=(4, 0), window_size=(4, 0),
), ),
} }
...@@ -1158,16 +1070,22 @@ def _run_dot_product_attention( ...@@ -1158,16 +1070,22 @@ def _run_dot_product_attention(
model_configs_te_layer = { model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
"te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), "te_1_1": ModelConfig(
"te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
"te_1_3": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), ),
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), "te_1_2": ModelConfig(
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), 2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), ),
"te_2_3": ModelConfig(1, 16, 16, 64, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"), "te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
"te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"), "te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"),
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"), "te_2_1": ModelConfig(2, 2048, 16, 64),
"te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"),
"te_2_3": ModelConfig(
1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
),
"te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
"te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
} }
...@@ -1189,26 +1107,27 @@ def test_transformer_layer( ...@@ -1189,26 +1107,27 @@ def test_transformer_layer(
tols = dict(atol=5e-2, rtol=5e-2) tols = dict(atol=5e-2, rtol=5e-2)
workspace_opt = True workspace_opt = True
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd"
# override the qkv_layout in mqa gqa mode in ROCm TE
if IS_HIP_EXTENSION and model_configs[model].num_gqa_groups != model_configs[model].num_heads:
qkv_layout = "sbhd_sbhd_sbhd"
# Test backend availability # Test backend availability
is_training = True is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends( available_backends, _, fused_attn_backends = get_available_attention_backends(
config, config,
qkv_dtype=dtype, qkv_dtype=dtype,
qkv_layout=qkv_layout, qkv_layout=(
qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
),
is_training=is_training, is_training=is_training,
) )
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported: if not fused_attn_supported:
is_training = False is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends( available_backends, _, fused_attn_backends = get_available_attention_backends(
config, config,
qkv_dtype=dtype, qkv_dtype=dtype,
qkv_layout=qkv_layout, qkv_layout=(
qkv_format.replace("hd", "h3d")
if fused_qkv_params
else qkv_format.replace("hd", "3hd")
),
is_training=is_training, is_training=is_training,
) )
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
...@@ -1514,20 +1433,164 @@ def _run_transformer_layer( ...@@ -1514,20 +1433,164 @@ def _run_transformer_layer(
return out, inp.grad return out, inp.grad
# Config for the attention extra-state checkpointing sanity test.
# Positional args presumably follow the new ModelConfig order
# (batch, max_seqlen_q, num_heads, head_dim) seen in sibling dicts — TODO confirm.
model_configs_fp8_extra_state = {
    "large": ModelConfig(2, 128, 4, 128, num_layers=1),
}
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
    """Sanity-check attention extra-state checkpointing.

    Runs the same FP8 training sequence three ways — without checkpointing,
    with a mid-run save/load round trip, and with a save/load round trip that
    mimics the pre-v1.6 extra-state key layout — and asserts all three produce
    matching outputs and gradients.
    """
    config = model_configs_fp8_extra_state[model]

    # Test backend availability
    is_training = True
    available_backends, _, fused_attn_backends = get_available_attention_backends(
        config,
        qkv_dtype=torch.float8_e4m3fn,
        qkv_layout="sb3hd",
        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
    if not fused_attn_supported and not flash_attn_supported:
        pytest.skip("No attention backend available.")

    outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
    outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
    outputs_checkpoint_v1_6 = _run_attention_extra_state(
        dtype, config, mimic_v1_6=True, checkpoint=True
    )

    # Check that results match
    tols = dtype_tols(dtype)
    if dtype in (torch.float16, torch.bfloat16):
        tols.update(dict(rtol=2e-2, atol=2e-3))
    # Plain zip: the index was unused, so enumerate was noise.
    for ref, test in zip(outputs, outputs_checkpoint):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )
    for ref, test in zip(outputs, outputs_checkpoint_v1_6):
        torch.testing.assert_close(
            test,
            ref,
            **tols,
        )
def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
    """Train a single TransformerLayer under FP8 autocast, optionally
    round-tripping its state through ``torch.save``/``torch.load`` mid-run.

    Runs ``steps // 2`` forward/backward iterations, then (if ``checkpoint``)
    saves the state dict — with ``mimic_v1_6`` relocating the attention
    ``_extra_state`` key to its old position — rebuilds the model from the
    checkpoint while preserving accumulated grads and RNG state, and runs the
    remaining ``(steps + 1) // 2`` iterations.

    Returns:
        list: ``[output, hidden_states.grad]`` followed by the grads of all
        trainable parameters after the final step.
    """
    steps = 10
    path = "checkpoint.pt"
    fp8_enabled = True
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
        fp8_dpa=fp8_enabled,
        fp8_mha=False,
    )

    reset_rng_states()
    # Input layout is (seq, batch, hidden) — matches the default sbhd-style
    # TransformerLayer input (TODO confirm against TransformerLayer docs).
    hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    def get_model(dtype, config):
        # Build one TransformerLayer with FP8-initialized, fused-QKV params.
        sigma = 0.023
        init_method = init_method_normal(sigma)
        output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

        with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
            block = TransformerLayer(
                config.hidden_size,
                4 * config.hidden_size,
                config.num_heads,
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                hidden_dropout=0.0,
                attention_dropout=0.0,
                fuse_qkv_params=True,
                params_dtype=dtype,
                device="cuda",
            )
        return block

    block = get_model(dtype, config)

    # First half of training; grads accumulate across steps (no zero_grad).
    for i in range(steps // 2):
        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
            output = block(hidden_states, None)
        loss = output.sum()
        loss.backward()

    if checkpoint:
        sd = block.state_dict()
        if mimic_v1_6:
            # Emulate the pre-v1.6 key layout: the attention extra state lived
            # under a "fused_attention" sub-key instead of core_attention.
            sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
                "self_attention.core_attention._extra_state"
            ]
            del sd["self_attention.core_attention._extra_state"]
        torch.save(sd, path)

        # Preserve accumulated grads and RNG state across the model rebuild so
        # the second half of training continues exactly where this one stopped.
        param_grads = []
        for p in block.parameters():
            if p.requires_grad:
                param_grads.append(p.grad.clone())

        _cpu_rng_state_new = torch.get_rng_state()
        _cuda_rng_state_new = torch.cuda.get_rng_state()

        del block
        block = get_model(dtype, config)
        block.load_state_dict(torch.load(path, weights_only=False))
        torch.set_rng_state(_cpu_rng_state_new)
        torch.cuda.set_rng_state(_cuda_rng_state_new)

        # Re-attach the saved grads to the freshly loaded parameters, in order.
        for p in block.parameters():
            if p.requires_grad:
                p.grad = param_grads.pop(0)

        assert not param_grads, "Oops!"

    # Second half of training (rounds up so the two halves total `steps`).
    for i in range((steps + 1) // 2):
        with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
            output = block(hidden_states, None)
        loss = output.sum()
        loss.backward()

    torch.cuda.synchronize()

    if os.path.exists(path):
        os.remove(path)

    outputs = [output, hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)

    return outputs
model_configs_fp8_vs_f16 = { model_configs_fp8_vs_f16 = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_9": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "fp8_9": ModelConfig(2, 2048, 16, 128),
"fp8_10": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
"fp8_11": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
"fp8_12": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
"fp8_13": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
"fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
"fp8_15": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding", "no_bias"), "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
"fp8_16": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding", "no_bias"), "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
"fp8_17": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding", "no_bias"), "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
"fp8_18": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
"fp8_19": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
"fp8_20": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding_causal", "no_bias"), "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
} }
param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
...@@ -1561,7 +1624,7 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): ...@@ -1561,7 +1624,7 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
) )
) )
@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
...@@ -1576,18 +1639,30 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, ...@@ -1576,18 +1639,30 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model] config = model_configs_fp8_vs_f16[model]
if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (
9,
7,
0,
):
pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
if ( # Test backend availability
FlashAttentionUtils.v3_is_installed available_backends, _, fused_attn_backends = get_available_attention_backends(
and not is_training config,
and "padding" not in config.attn_mask_type qkv_dtype=torch.float8_e4m3fn,
): qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
if not fp8_dpa_bwd:
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("No attention backend available.")
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True _attention_backends["backend_selection_requires_update"] = True
...@@ -1613,11 +1688,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, ...@@ -1613,11 +1688,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
rtol = 5e-1 rtol = 5e-1
rmse_tol = 0.15 rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output")) logging.debug("========== {:^25s} ==========".format("forward output"))
if ( if flash_attn_supported:
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
_error( _error(
flash_attn_fwd_fp8, flash_attn_fwd_fp8,
fused_attn_fwd_f16, fused_attn_fwd_f16,
...@@ -1768,7 +1839,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP ...@@ -1768,7 +1839,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
return out, param_names, tuple(x.grad for x in params) return out, param_names, tuple(x.grad for x in params)
return out, param_names, tuple(None for x in params) return out, param_names, tuple(None for x in params)
@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
...@@ -1790,23 +1861,34 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): ...@@ -1790,23 +1861,34 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
# if get_device_compute_capability() >= (10, 0): # if get_device_compute_capability() >= (10, 0):
# config.dropout_p = 0.1 # config.dropout_p = 0.1
if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (
9,
7,
0,
):
pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
if ( # Test backend availability
FlashAttentionUtils.v3_is_installed available_backends, _, fused_attn_backends = get_available_attention_backends(
and not is_training config,
and "padding" not in config.attn_mask_type qkv_dtype=torch.float8_e4m3fn,
): qkv_layout=qkv_layout,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if flash_attn_supported + fused_attn_supported < 1:
pytest.skip("No FP8 attention backend available.")
if not fp8_dpa_bwd:
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("No attention backend available.")
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True _attention_backends["backend_selection_requires_update"] = True
...@@ -1835,11 +1917,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): ...@@ -1835,11 +1917,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
rmse_tol = 0.11 rmse_tol = 0.11
bwd_names = ["dq", "dk", "dv"] bwd_names = ["dq", "dk", "dv"]
logging.debug("========== {:^25s} ==========".format("forward output")) logging.debug("========== {:^25s} ==========".format("forward output"))
if ( if flash_attn_supported:
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
_error( _error(
flash_attn_fwd_fp8, flash_attn_fwd_fp8,
fused_attn_fwd_f16, fused_attn_fwd_f16,
...@@ -2013,21 +2091,21 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): ...@@ -2013,21 +2091,21 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
model_configs_fp8 = { model_configs_fp8 = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"), "fp8_1": ModelConfig(1, 512, 1, 64),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), "fp8_2": ModelConfig(4, 512, 16, 64),
"fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "fp8_3": ModelConfig(1, 2048, 1, 128),
"fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "fp8_4": ModelConfig(2, 2048, 24, 128),
"fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"), "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"),
"fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"), "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"),
"fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"),
"fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
} }
param_types_fp8 = [torch.float16, torch.bfloat16] param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1")) cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"] models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"] models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif( @pytest.mark.skipif(
( (
get_cudnn_version() < (8, 9, 3) get_cudnn_version() < (8, 9, 3)
...@@ -2049,6 +2127,18 @@ def test_custom_mha_fp8_vs_f16(dtype, model): ...@@ -2049,6 +2127,18 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
config = model_configs_fp8[model] config = model_configs_fp8[model]
# Test backend availability
is_training = True
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=torch.float8_e4m3fn,
qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not (fused_attn_backends and unfused_attn_supported):
pytest.skip("Not enough backends to run this test with.")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention") unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
import os import os
import subprocess import subprocess
import sys
import pathlib
import pytest import pytest
import torch import torch
...@@ -12,27 +14,29 @@ from transformer_engine.pytorch.utils import ( ...@@ -12,27 +14,29 @@ from transformer_engine.pytorch.utils import (
get_cudnn_version, get_cudnn_version,
) )
from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
from test_fused_attn import ModelConfig
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig, get_available_attention_backends
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
model_configs_flash_attn = { model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig( "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) "cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)), # MHA
), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
"cp_1_3": ModelConfig( "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig( "cp_2_2": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
), # GQA
"cp_2_3": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # GQA ), # GQA
"cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA
} }
...@@ -44,7 +48,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): ...@@ -44,7 +48,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
"--nproc-per-node=" + str(num_gpus_per_node), "--nproc-per-node=" + str(num_gpus_per_node),
] ]
te_path = os.getenv("TE_PATH", "/opt/transformerengine") te_path = os.getenv("TE_PATH", "/opt/transformerengine")
script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py") script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py")
args.append(script_path) args.append(script_path)
for k, v in kwargs.items(): for k, v in kwargs.items():
args.append(f"{k}={v}") args.append(f"{k}={v}")
...@@ -94,37 +98,41 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ...@@ -94,37 +98,41 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
model_configs_fused_attn = { model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA "cp_1_2": ModelConfig(
"cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
"cp_1_4": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA ), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
"cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
"cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
"cp_2_2": ModelConfig(
2,
4096,
12,
128,
num_gqa_groups=2,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
), # GQA
"cp_2_3": ModelConfig(
2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
), # GQA
"cp_2_4": ModelConfig( "cp_2_4": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
), # GQA ), # GQA
"cp_3_0": ModelConfig( "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64 "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA
), # MLA
"cp_3_1": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64
), # MLA
"cp_3_2": ModelConfig( "cp_3_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64
), # MLA ), # MLA
"cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA
} }
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0), reason="DTK not surpport fused attn for now, CP tests require sm80+.") @pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) @pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
...@@ -176,6 +184,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha ...@@ -176,6 +184,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("MLA CP currently only support KV P2P!") pytest.skip("MLA CP currently only support KV P2P!")
if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently does not support FP8 attention!") pytest.skip("MLA CP currently does not support FP8 attention!")
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtypes[dtype],
qkv_layout="_".join([qkv_format] * 3),
window_size=config.window_size,
context_parallel=True,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("No attention backend available.")
subprocess.run( subprocess.run(
get_bash_arguments( get_bash_arguments(
......
...@@ -5,18 +5,14 @@ ...@@ -5,18 +5,14 @@
from collections import OrderedDict from collections import OrderedDict
from typing import List from typing import List
import os import os
import sys
import pathlib
import logging import logging
import math import math
import pytest import pytest
import torch import torch
from test_fused_attn import (
ModelConfig,
reset_rng_states,
_get_attention_backends,
)
from torch.distributions import Exponential from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe from transformer_engine.common import recipe
...@@ -34,26 +30,25 @@ from transformer_engine.pytorch.utils import ( ...@@ -34,26 +30,25 @@ from transformer_engine.pytorch.utils import (
is_bf16_compatible, is_bf16_compatible,
) )
# Initialize RNG state _current_file = pathlib.Path(__file__).resolve()
seed = 1234 sys.path.append(str(_current_file.parent.parent))
torch.manual_seed(seed) from utils import (
torch.cuda.manual_seed(seed) ModelConfig,
_cpu_rng_state = torch.get_rng_state() reset_rng_states,
_cuda_rng_state = torch.cuda.get_rng_state() get_available_attention_backends,
)
# Reset RNG states
reset_rng_states()
param_types = [torch.float16] param_types = [torch.float16]
if is_bf16_compatible(): if is_bf16_compatible():
param_types.append(torch.bfloat16) param_types.append(torch.bfloat16)
model_configs_infer = { model_configs_infer = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, sq, hq, dqk,
"infer_0": ModelConfig( "infer_0": ModelConfig(4, 64, 16, 128, total_requests=8, max_ctx_len=16),
4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 "infer_1": ModelConfig(2, 66, 16, 256, num_gqa_groups=4, total_requests=6, max_ctx_len=16),
),
"infer_1": ModelConfig(
2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
),
} }
qkv_formats = ["bshd", "sbhd", "thd"] qkv_formats = ["bshd", "sbhd", "thd"]
...@@ -470,7 +465,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g ...@@ -470,7 +465,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2) qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2)
if is_paged: if is_paged:
qkv_layout = "paged_kv_" + qkv_layout qkv_layout = "paged_kv_" + qkv_layout
available_backends, _, fused_attn_backends = _get_attention_backends( available_backends, _, fused_attn_backends = get_available_attention_backends(
config, config,
qkv_dtype=dtype, qkv_dtype=dtype,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
......
...@@ -364,6 +364,40 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs): ...@@ -364,6 +364,40 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
set_weight_tensor_tp_group_reduce(True) # reset set_weight_tensor_tp_group_reduce(True) # reset
@run_debug_test
def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
    """Sanity check: quantized-stats logging under combined TP x DP parallelism.

    Writes ``LOG_QUANTIZED_CONFIG`` (reused from ``test_log``) into the per-test
    config file, initializes the debug session, splits the world into 2 DP groups
    of ``WORLD_SIZE // 2`` TP ranks each, and runs one forward/backward pass so
    the logging feature executes. No-op unless ``WORLD_SIZE`` is even.
    """
    from test_log import LOG_QUANTIZED_CONFIG
    kwargs["config_file"].write(LOG_QUANTIZED_CONFIG)
    kwargs["config_file"].flush()
    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
    set_weight_tensor_tp_group_reduce(gather_weight)
    if WORLD_SIZE % 2 != 0:
        return  # skip
    TP_SIZE = WORLD_SIZE // 2
    DP_SIZE = 2  # fixed at two data-parallel groups (not referenced below)
    TP_RANK = WORLD_RANK % TP_SIZE
    DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE
    debug_api.set_tensor_reduction_group(NCCL_WORLD)
    # Weight seed depends only on the TP rank (same shard across DP replicas);
    # data seed depends only on the DP rank (different data per DP group).
    x, weight = _get_tensors(
        parallel_mode,
        weight_seed=TP_RANK * 1234,
        data_seed=DP_RANK * 1234,
        tp_size=TP_SIZE,
        tp_rank=TP_RANK,
    )
    # Ranks belonging to this rank's own TP group.
    # NOTE(review): each process calls dist.new_group() only with its own
    # group's ranks; torch.distributed documents that every process must
    # create every group in the same order -- confirm this is safe here.
    tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)]
    tp_group = dist.new_group(ranks=tp_group_ranks)
    model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group)
    _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group)
    set_weight_tensor_tp_group_reduce(True)  # reset
@run_debug_test @run_debug_test
def test_log_expert_parallel(**kwargs): def test_log_expert_parallel(**kwargs):
""" """
......
...@@ -24,22 +24,17 @@ def test_transformer_engine_no_config(feature_dirs): ...@@ -24,22 +24,17 @@ def test_transformer_engine_no_config(feature_dirs):
# FP8 enabled - true by the default # FP8 enabled - true by the default
assert debug_api.transformer_engine.fp8_gemm_enabled( assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0 "decoder.1.attn.qkv", gemm="fprop", iteration=0
) )[0]
# modify_tensor_enabled - False by default # modify_tensor_enabled - (False, None) by default
assert not debug_api.transformer_engine.modify_tensor_enabled( assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0 "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
) )[0]
# inspect_tensor_enabled - False by default # inspect_tensor_enabled - (False, None) by default
assert not debug_api.transformer_engine.inspect_tensor_enabled( assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.attn.qkv", tensor_name="activation", iteration=0 "decoder.1.attn.qkv", tensor_name="activation", iteration=0
) )[0]
# inspect_tensor_postquantize - False by default
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)
finally: finally:
debug_api.end_debug() debug_api.end_debug()
...@@ -51,24 +46,24 @@ def test_disable_fp8_gemm(configs_dir, feature_dirs): ...@@ -51,24 +46,24 @@ def test_disable_fp8_gemm(configs_dir, feature_dirs):
assert debug_api.transformer_engine.fp8_gemm_enabled( assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0 "decoder.1.attn.qkv", gemm="fprop", iteration=0
) )[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled( assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0 "decoder.1.attn.qkv", gemm="dgrad", iteration=0
) )[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled( assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0 "decoder.1.attn.qkv", gemm="wgrad", iteration=0
) )[0]
# caching # caching
assert debug_api.transformer_engine.fp8_gemm_enabled( assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0 "decoder.1.attn.qkv", gemm="fprop", iteration=0
) )[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled( assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0 "decoder.1.attn.qkv", gemm="dgrad", iteration=0
) )[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled( assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0 "decoder.1.attn.qkv", gemm="wgrad", iteration=0
) )[0]
finally: finally:
debug_api.end_debug() debug_api.end_debug()
...@@ -80,22 +75,22 @@ def test_disable_fp8_layer(configs_dir, feature_dirs): ...@@ -80,22 +75,22 @@ def test_disable_fp8_layer(configs_dir, feature_dirs):
assert debug_api.transformer_engine.fp8_gemm_enabled( assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="fprop", iteration=0 "decoder.1.mlp.fc1", gemm="fprop", iteration=0
) )[0]
assert debug_api.transformer_engine.fp8_gemm_enabled( assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", iteration=0 "decoder.1.mlp.fc1", gemm="wgrad", iteration=0
) )[0]
assert debug_api.transformer_engine.fp8_gemm_enabled( assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", iteration=0 "decoder.1.mlp.fc1", gemm="dgrad", iteration=0
) )[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled( assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0 "decoder.1.attn.qkv", gemm="fprop", iteration=0
) )[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled( assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0 "decoder.1.attn.qkv", gemm="wgrad", iteration=0
) )[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled( assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0 "decoder.1.attn.qkv", gemm="dgrad", iteration=0
) )[0]
finally: finally:
debug_api.end_debug() debug_api.end_debug()
...@@ -111,22 +106,22 @@ def test_per_tensor_scaling(configs_dir, feature_dirs): ...@@ -111,22 +106,22 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
# check modify_tensor_enabled # check modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled( assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0 "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
) )[0]
assert debug_api.transformer_engine.modify_tensor_enabled( assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0 "decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0
) )[0]
assert debug_api.transformer_engine.modify_tensor_enabled( assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0 "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
) )[0]
assert not debug_api.transformer_engine.modify_tensor_enabled( assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0 "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0
) )[0]
assert not debug_api.transformer_engine.modify_tensor_enabled( assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0 "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0
) )[0]
assert not debug_api.transformer_engine.modify_tensor_enabled( assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0 "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0
) )[0]
# check modify_tensor # check modify_tensor
...@@ -168,14 +163,14 @@ def test_per_tensor_scaling(configs_dir, feature_dirs): ...@@ -168,14 +163,14 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
gemm="wgrad", gemm="wgrad",
tensor_name="gradient", tensor_name="gradient",
iteration=0, iteration=0,
) )[0]
assert not debug_api.transformer_engine.modify_tensor_enabled( assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc4", "decoder.1.mlp.fc4",
gemm="fprop", gemm="fprop",
tensor_name="activation", tensor_name="activation",
iteration=0, iteration=0,
) )[0]
finally: finally:
debug_api.end_debug() debug_api.end_debug()
...@@ -191,11 +186,11 @@ def test_fake_quant(configs_dir, feature_dirs): ...@@ -191,11 +186,11 @@ def test_fake_quant(configs_dir, feature_dirs):
# modify_tensor_enabled # modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled( assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0 "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
) )[0]
assert debug_api.transformer_engine.modify_tensor_enabled( assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0 "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
) )[0]
# modify_tensor # modify_tensor
debug_api.transformer_engine.modify_tensor( debug_api.transformer_engine.modify_tensor(
...@@ -218,11 +213,11 @@ def test_fake_quant(configs_dir, feature_dirs): ...@@ -218,11 +213,11 @@ def test_fake_quant(configs_dir, feature_dirs):
assert debug_api.transformer_engine.fp8_gemm_enabled( assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0 "decoder.1.fc2", gemm="wgrad", iteration=0
) )[0]
# caching # caching
assert debug_api.transformer_engine.fp8_gemm_enabled( assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0 "decoder.1.fc2", gemm="wgrad", iteration=0
) )[0]
finally: finally:
debug_api.end_debug() debug_api.end_debug()
...@@ -236,13 +231,12 @@ def test_statistics_collection(configs_dir, feature_dirs): ...@@ -236,13 +231,12 @@ def test_statistics_collection(configs_dir, feature_dirs):
) )
tensor = torch.randn((100, 100, 5)).cuda() tensor = torch.randn((100, 100, 5)).cuda()
tensor_fp8 = Float8Tensor( quantizer = Float8Quantizer(
data=tensor.to(torch.uint8).cuda(), scale=torch.full([1], 1.0).cuda(),
fp8_scale_inv=torch.full([1], 1.0).cuda(), amax=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3, fp8_dtype=tex.DType.kFloat8E4M3,
shape=tensor.shape,
dtype=torch.float32,
) )
tensor_fp8 = quantizer(tensor)
def log(): def log():
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
...@@ -260,54 +254,64 @@ def test_statistics_collection(configs_dir, feature_dirs): ...@@ -260,54 +254,64 @@ def test_statistics_collection(configs_dir, feature_dirs):
tensor_name="activation", tensor_name="activation",
iteration=200, iteration=200,
tp_group=None, tp_group=None,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
) )
stats = log() stats = log()
assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max() assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
assert not debug_api.transformer_engine.inspect_tensor_enabled( assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", iteration=201 "decoder.1.mlp.fc1", tensor_name="activation", iteration=201
) )[0]
assert not debug_api.transformer_engine.inspect_tensor_enabled( assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.2.mlp.fc1", tensor_name="activation", iteration=200 "decoder.2.mlp.fc1", tensor_name="activation", iteration=200
) )[0]
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", iteration=200 expected_underflows = (
((tensor_fp8._data == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5)
) )
expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5) assert debug_api.transformer_engine.inspect_tensor_enabled(
expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5) "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
)[0]
# TE FP8 tensor stats -- # TE FP8 tensor stats --
assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled( assert debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200 "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
) )[0]
debug_api.transformer_engine.inspect_tensor_postquantize( debug_api.transformer_engine.inspect_tensor(
"decoder.1.mlp.fc1", "decoder.1.mlp.fc1",
tensor=tensor_fp8,
tensor_name="gradient", tensor_name="gradient",
iteration=200, iteration=200,
rowwise=True,
tp_group=None, tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
) )
stats = log() stats = log()
torch.testing.assert_close( torch.testing.assert_close(
stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows
) )
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201 "decoder.1.mlp.fc1", tensor_name="activation", iteration=201
) )[0]
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200 "decoder.2.mlp.fc1", tensor_name="gradient", iteration=200
) )[0]
# Second config in same yaml # Second config in same yaml
tensor = torch.rand((100, 100, 5)) tensor = torch.rand((100, 100, 5))
debug_api.transformer_engine.inspect_tensor( debug_api.transformer_engine.inspect_tensor(
"decoder.6.mlp.fc1", "decoder.6.mlp.fc1",
tensor=tensor,
tensor_name="activation", tensor_name="activation",
iteration=200, iteration=200,
tp_group=None, tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
) )
stats = log() stats = log()
stats_names = [x[3] for x in stats.keys()] stats_names = [x[3] for x in stats.keys()]
...@@ -316,10 +320,13 @@ def test_statistics_collection(configs_dir, feature_dirs): ...@@ -316,10 +320,13 @@ def test_statistics_collection(configs_dir, feature_dirs):
debug_api.transformer_engine.inspect_tensor( debug_api.transformer_engine.inspect_tensor(
"decoder.7.mlp.fc1", "decoder.7.mlp.fc1",
tensor=tensor,
tensor_name="weight", tensor_name="weight",
iteration=200, iteration=200,
tp_group=None, tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
) )
stats = log() stats = log()
stats_names = [x[3] for x in stats.keys()] stats_names = [x[3] for x in stats.keys()]
...@@ -328,7 +335,7 @@ def test_statistics_collection(configs_dir, feature_dirs): ...@@ -328,7 +335,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
assert not debug_api.transformer_engine.inspect_tensor_enabled( assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.7.mlp.fc1", tensor_name="weight", iteration=201 "decoder.7.mlp.fc1", tensor_name="weight", iteration=201
) )[0]
assert_empty() assert_empty()
finally: finally:
...@@ -343,21 +350,16 @@ def test_statistics_multi_run(configs_dir, feature_dirs): ...@@ -343,21 +350,16 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
default_logging_enabled=False, default_logging_enabled=False,
) )
def feed(tensor, tensor_fp8): def feed(tensor, tensor_fp8, quantizer):
debug_api.transformer_engine.inspect_tensor( debug_api.transformer_engine.inspect_tensor(
"decoder.5.mlp.fc1", "decoder.5.mlp.fc1",
tensor=tensor, tensor=tensor,
tensor_name="activation", tensor_name="activation",
iteration=1, iteration=1,
tp_group=None, tp_group=None,
) quantizer=quantizer,
debug_api.transformer_engine.inspect_tensor_postquantize( rowwise_quantized_tensor=tensor_fp8,
"decoder.5.mlp.fc1", columnwise_quantized_tensor=tensor_fp8,
tensor=tensor_fp8,
tensor_name="activation",
iteration=1,
rowwise=True,
tp_group=None,
) )
def log_stats(): def log_stats():
...@@ -365,26 +367,26 @@ def test_statistics_multi_run(configs_dir, feature_dirs): ...@@ -365,26 +367,26 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
return STATS_BUFFERS.log_stats() return STATS_BUFFERS.log_stats()
quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
def fp8_tensor(t): def fp8_tensor(t):
return Float8Tensor( return quantizer(t.cuda())
data=t.to(torch.uint8).cuda(),
fp8_scale_inv=torch.ones([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=t.shape,
dtype=torch.float32,
)
shape = [1024, 1024] shape = [1024, 1024]
tensors = [torch.randn(shape) for _ in range(2)] tensors = [torch.randn(shape) for _ in range(2)]
tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)] tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)]
feed(tensors[0], tensors_fp8[0]) feed(tensors[0], tensors_fp8[0], quantizer)
feed(tensors[1], tensors_fp8[1]) feed(tensors[1], tensors_fp8[1], quantizer)
stats1 = log_stats() stats1 = log_stats()
tensor2 = torch.cat((tensors[0], tensors[1])).cuda() tensor2 = torch.cat((tensors[0], tensors[1])).cuda()
fp8tensor2 = fp8_tensor(tensor2) fp8tensor2 = fp8_tensor(tensor2)
feed(tensor2, fp8tensor2) feed(tensor2, fp8tensor2, quantizer)
stats2 = log_stats() stats2 = log_stats()
assert len(stats1.keys()) > 0 assert len(stats1.keys()) > 0
......
test:
enabled: True
layers:
layer_name_regex_pattern: .*
transformer_engine:
LogTensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [cur_amax, dynamic_range, mean, std, l1_norm]
start_step: 1
freq: 3
LogFp8TensorStats:
enabled: True
tensors: activation
stats: [underflows%]
start_step: 1
freq: 5
\ No newline at end of file
test:
enabled: True
layers:
layer_name_regex_pattern: .*1
transformer_engine:
LogTensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [cur_amax, dynamic_range, mean, std, l1_norm]
start_step: 0
freq: 100000
\ No newline at end of file
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as te
import torch
import tempfile
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import RecipeState
import pytest
import contextlib
import os
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.debug.pytorch.debug_state import TEDebugState
# Probe once at import time which FP8 flavors the current GPU/software stack
# supports; tests below skip (with the recorded reason) based on these flags.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)
LOG_QUANTIZED_CONFIG_BASE = """
log:
layers:
layer_name_regex_pattern: .*
enabled:
True
transformer_engine:
LogFp8TensorStats:
enabled: True
stats: [
{stats}
]
tensors: [activation, gradient, weight]
freq: 2
start_step: 0
end_step: 10
"""
# Recipe-name prefixes used to build qualified stat identifiers such as
# "mxfp8_underflows%_columnwise" (see the all_stats construction below).
recipes = [
    "fp8_delayed_scaling",
    "fp8_current_scaling",
    "fp8_block_scaling",
    "mxfp8",
]

# Unprefixed stat names; also consumed directly by test_numerics.
bare_stats = [
    "underflows%",
    "scale_inv_min",
    "scale_inv_max",
    "mse",
]
# Build the full list of "<recipe>_<stat>[_columnwise]" names supported on the
# current GPU. The device-capability gates depend only on the recipe, so they
# are checked once per recipe instead of once per (recipe, stat, postfix)
# combination as before -- fewer redundant torch.cuda queries, same result.
all_stats = []
for r in recipes:
    if (
        r in ["fp8_current_scaling", "fp8_block_scaling"]
        and torch.cuda.get_device_capability()[0] < 9
    ):
        # hopper is needed for current-scaling, block-scaling
        continue
    if r == "mxfp8" and torch.cuda.get_device_capability()[0] < 10:
        # blackwell is needed for mxfp8
        continue
    for stat in bare_stats:
        for columnwise_postfix in ["", "_columnwise"]:
            if (
                r in ["fp8_delayed_scaling", "fp8_current_scaling"]
                and columnwise_postfix == "_columnwise"
            ):
                # columnwise stats are not supported for fp8_delayed_scaling and fp8_current_scaling
                continue
            all_stats.append(f"{r}_{stat}{columnwise_postfix}")
all_stats.append("fp8_delayed_scaling_overflows%")  # only delayed-scaling supports overflows%
@contextlib.contextmanager
def debug_session(config_str: str, feature_dirs):
    """
    Helper context manager that
    1. writes the YAML `config_str` to a temporary file,
    2. starts a debug session, and
    3. yields the directory that contains the statistics log.

    The session is closed automatically – even on exceptions – and the
    temporary config file is deleted, so every test stays concise and leak-free.
    """
    cfg_path = None
    try:
        # delete=False so the file can be re-opened by name (required on
        # Windows while the handle is still open); we unlink it ourselves
        # below -- the original version leaked one temp file per session.
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False
        ) as cfg_file, tempfile.TemporaryDirectory() as log_dir:
            cfg_path = cfg_file.name
            cfg_file.write(config_str)
            cfg_file.flush()
            debug_api.initialize(
                config_file=cfg_file.name,
                feature_dirs=feature_dirs,
                log_dir=log_dir,
            )
            try:
                yield log_dir
            finally:
                debug_api.end_debug()
    finally:
        # The NamedTemporaryFile is closed by its `with` block before we get
        # here, so the unlink also succeeds on Windows.
        if cfg_path is not None:
            try:
                os.unlink(cfg_path)
            except OSError:
                pass
def read_log(log_dir: str) -> str:
    """Return the content of the statistics log produced by `debug_session`."""
    log_file = os.path.join(
        log_dir,
        "nvdlfw_inspect_statistics_logs",
        "nvdlfw_inspect_globalrank-0.log",
    )
    with open(log_file, "r") as handle:
        content = handle.read()
    return content
def test_sanity(feature_dirs):
    """Smoke test: after 10 FP8 fwd/bwd steps, every stat name in ``all_stats``
    must appear in the statistics log."""
    config_yaml = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats))
    with debug_session(config_yaml, feature_dirs) as log_dir:
        layer = te.Linear(128, 128, params_dtype=torch.bfloat16)
        data = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
        for _ in range(10):
            with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
                out = layer(data)
            out.sum().backward()
            debug_api.step()
        log_text = read_log(log_dir)

    assert log_text, "Output is empty"
    for stat in all_stats:
        assert stat in log_text, f"Stat {stat} not found in output"
# One recipe instance per FP8 flavor; test_numerics is parametrized over these.
fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),
    recipe.Float8CurrentScaling(),
    recipe.Float8BlockScaling(),
]
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_numerics(fp8_recipe, feature_dirs):
    """Validate logged quantization stats against values recomputed locally.

    Quantizes a tensor, feeds it to ``inspect_tensor``, then parses the
    statistics log and compares underflows%, mse and overflows% with reference
    values computed from the dequantized tensor.
    """
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
        pytest.skip(reason_for_no_mxfp8)
    if not fp8_block_scaling_available and fp8_recipe == recipe.Float8BlockScaling():
        pytest.skip(reason_for_no_fp8_block_scaling)

    log_only_bare_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(bare_stats))
    with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
        recipe_state = RecipeState.create(
            fp8_recipe,
            mode="forward",
            num_quantizers=3,
        )
        # Mostly-zero tensor with one large row, presumably so that
        # quantization visibly distorts the values -- TODO confirm intent.
        tensor = torch.zeros(1024, 1024).cuda()
        tensor[0, :] = 1000
        quantizer = recipe_state.make_quantizers()[0]
        quantized_tensor = quantizer(tensor)
        debug_api.transformer_engine.inspect_tensor(
            layer_name="layer_name",
            tensor_name="activation",
            iteration=0,
            tp_group=None,
            tensor=tensor,
            quantizer=quantizer,
            rowwise_quantized_tensor=quantized_tensor,
            columnwise_quantized_tensor=quantized_tensor,
        )
        debug_api.step()
        dequantized_tensor = quantized_tensor.dequantize()
        output = read_log(log_dir)
        # Each log line carries one stat as "... value=<float>".
        for line in output.splitlines():
            if "underflows%" in line:
                underflows = float(line.split("value=")[1])
                # Expected: % of elements that became zero through quantization,
                # excluding elements that were already exactly zero.
                expected = (
                    ((dequantized_tensor == 0).sum() - (tensor == 0).sum())
                    / dequantized_tensor.numel()
                    * 100
                )
                assert underflows == pytest.approx(expected.cpu(), abs=1e-4)
            if "mse" in line:
                mse = float(line.split("value=")[1])
                # Expected: mean squared quantization error.
                expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean")
                assert mse == pytest.approx(expected.cpu(), abs=1e-6)
            if "overflows%" in line:
                overflows = float(line.split("value=")[1])
                # Expected: % of elements whose magnitude grew through
                # quantization (NOTE(review): this is how "overflow" is
                # measured here -- confirm against the feature definition).
                expected = (
                    (abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100
                )
                assert overflows == pytest.approx(expected.cpu(), abs=1e-4)
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
    # If layer does not invoke any feature in current iteration,
    # then it changed into non-debug mode.
    # This test checks whether this works correctly -
    # non-quantized statistics should be logged every 3 iterations,
    # and quantized statistics should be logged every 5 iterations.
    with tempfile.TemporaryDirectory() as temp_dir:
        debug_api.initialize(
            config_file=configs_dir + "/log_config.yaml",
            feature_dirs=feature_dirs,
            log_dir=temp_dir,
        )
        # Close the debug session even when an assertion below fails, so a
        # failure here cannot leak debug state into subsequent tests.
        try:
            if layer == "linear":
                model = te.Linear(128, 128, name="linear1")
            elif layer == "transformer":
                model = te.TransformerLayer(128, 128, 4, name="transformer1")
            else:
                raise ValueError(f"Invalid layer: {layer}")
            for _ in range(20):
                x = torch.randn(4, 128, 128).cuda()
                with te.fp8_autocast(enabled=True):
                    y = model(x)
                y.sum().backward()
                debug_api.step()
            with open(
                os.path.join(
                    temp_dir, "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"
                ),
                "r",
            ) as f:
                file_content = f.read()
            # Stats are expected exactly at multiples of 3 (non-quantized) or
            # 5 (quantized); any other iteration must not appear in the log.
            for i in range(1, 20):
                if i % 3 == 0 or i % 5 == 0:
                    assert f"iteration={i:06d}" in file_content
                else:
                    assert f"iteration={i:06d}" not in file_content
        finally:
            debug_api.end_debug()
            TEDebugState._reset()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import time
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug.pytorch.debug_state import TEDebugState
def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs):
    """Run `layer` for many iterations on a tiny input and return wall time.

    Used by test_cpu_overhead to compare CPU cost with and without the debug
    tooling initialized. Any previous debug session is torn down first so the
    measurement starts from a clean state.
    """
    debug_api.end_debug()
    TEDebugState._reset()
    if debug_tools_initialized:
        # This config log stats starting from 0, every N iterations for huge N >> NUM_ITERS.
        # So after 1 warm-up iteration, this layers should work in non-debug mode.
        debug_api.initialize(
            config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs
        )
    try:
        if layer == "linear":
            model = torch.nn.Sequential(
                te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2")
            ).cuda()
            NUM_ITERS = 18000
        elif layer == "transformer":
            model = torch.nn.Sequential(
                te.TransformerLayer(1, 1, 1, name="transformer1"),
                te.TransformerLayer(1, 1, 1, name="transformer2"),
            ).cuda()
            NUM_ITERS = 2000
        else:
            # Fail fast with a clear message instead of a confusing NameError
            # on `model`/`NUM_ITERS` below.
            raise ValueError(f"Invalid layer: {layer}")
        x = torch.randn(1, 1, 1).cuda()
        # Warm-up iteration: lets the layer switch into non-debug mode before
        # the timed loop starts.
        y = model(x)
        y.sum().backward()
        debug_api.step()
        torch.cuda.synchronize()
        time_start = time.time()
        for _ in range(NUM_ITERS):
            y = model(x)
            y.sum().backward()
            if debug_tools_initialized:
                debug_api.step()
        torch.cuda.synchronize()
        time_end = time.time()
    finally:
        if debug_tools_initialized:
            debug_api.end_debug()
    return time_end - time_start
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_cpu_overhead(layer, configs_dir, feature_dirs):
    """Ensure idle debug tooling adds no significant CPU overhead.

    Runs one tiny layer many times so GPU time is negligible and wall time is
    dominated by CPU work. A layer that invokes no debug feature in the
    current iteration switches into non-debug mode, so its CPU cost should be
    close to a run without debug tools initialized at all.
    """
    debug_time = _run_cpu_overhead(True, layer, configs_dir, feature_dirs)
    baseline_time = _run_cpu_overhead(False, layer, configs_dir, feature_dirs)
    print(f"with_debug_tools: {debug_time} s")
    print(f"without_debug_tools: {baseline_time} s")
    # Allow a 25% margin over the baseline run.
    assert debug_time < baseline_time * 1.25
...@@ -519,6 +519,7 @@ def _train(opts): ...@@ -519,6 +519,7 @@ def _train(opts):
if opts.use_cuda_graphs: if opts.use_cuda_graphs:
del test_graph del test_graph
torch.cuda.synchronize()
te.module.base.destroy_ub() te.module.base.destroy_ub()
dist_print("Destroying Userbuffers objects...", debug=True) dist_print("Destroying Userbuffers objects...", debug=True)
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pathlib
import sys
import pytest
import torch
import transformer_engine
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch import TransformerLayer, Linear
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig
# A tiny model so the multi-GPU device-placement test runs fast.
# NOTE(review): positional args presumably map to batch size, sequence length,
# number of heads and head dim — confirm against utils.ModelConfig.
model_configs = {
    "small": ModelConfig(2, 10, 2, 16),
}
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"])
def test_current_device(model, module):
    """Test cases where current device is different from tensor device"""
    num_devices = torch.cuda.device_count()
    assert num_devices > 1, "This test requires more than one GPU!"
    # Keep the current CUDA device at its default while placing the module and
    # all tensors on the *last* device, so any kernel that implicitly relies
    # on the current device would misplace its output.
    tensor_device = num_devices - 1
    dtype = torch.bfloat16
    config = model_configs[model]
    args = []
    kwargs = {}
    bwd_args = []
    if module == "TransformerLayer":
        model = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            params_dtype=dtype,
            attn_input_format="thd",
            self_attn_mask_type="padding",
            device=f"cuda:{tensor_device}",
        )
        num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
        args = [
            torch.randn(
                (num_tokens, config.hidden_size),
                dtype=dtype,
                device=f"cuda:{tensor_device}",
                requires_grad=True,
            )
        ]
        cu_seqlens_q, cu_seqlens_kv = [
            torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
        ]
        kwargs["cu_seqlens_q"] = cu_seqlens_q
        kwargs["cu_seqlens_kv"] = cu_seqlens_kv
        kwargs["max_seqlen_q"] = config.max_seqlen_q
        kwargs["max_seqlen_kv"] = config.max_seqlen_kv
    elif module == "DotProductAttention":
        # NOTE(review): was a bare `if`; made `elif` for consistency with the
        # other mutually-exclusive branches (behavior unchanged).
        model = DotProductAttention(
            config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
        )
        num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
        # Three identical-shaped tensors serve as q, k and v.
        args = [
            torch.randn(
                num_tokens,
                config.num_heads,
                config.head_dim_qk,
                dtype=dtype,
                device=tensor_device,
                requires_grad=True,
            )
            for _ in range(3)
        ]
        cu_seqlens_q, cu_seqlens_kv = [
            torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
        ]
        kwargs["cu_seqlens_q"] = cu_seqlens_q
        kwargs["cu_seqlens_kv"] = cu_seqlens_kv
        kwargs["max_seqlen_q"] = config.max_seqlen_q
        kwargs["max_seqlen_kv"] = config.max_seqlen_kv
        bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)]
    elif module == "Linear":
        model = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            params_dtype=dtype,
            device=f"cuda:{tensor_device}",
        )
        args = [
            torch.randn(
                (config.max_seqlen_q, config.batch_size, config.hidden_size),
                dtype=dtype,
                device=f"cuda:{tensor_device}",
                requires_grad=True,
            )
        ]
    current_device_before = torch.cuda.current_device()
    out = model(*args, **kwargs)
    if module == "DotProductAttention":
        # DPA output cannot simply be summed here; use the prepared gradient.
        out.backward(*bwd_args)
    else:
        loss = out.sum()
        loss.backward()
    current_device_after = torch.cuda.current_device()
    tensor_device_out = out.get_device()
    tensor_device_grad = args[0].grad.get_device()
    assert (
        current_device_after == current_device_before
    ), "The current device should not have changed!"
    assert (
        tensor_device_out == tensor_device
    ), "The output tensor should be the same as the input tensors!"
    assert (
        tensor_device_grad == tensor_device
    ), "The gradient tensor should be the same as the input tensors!"
...@@ -10,22 +10,24 @@ import torch ...@@ -10,22 +10,24 @@ import torch
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from utils import ModelConfig, get_available_attention_backends
# Check if FP8 is supported # Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_recipes = [ fp8_recipes = [None]
None, # non-fp8 if fp8_available:
# recipe.MXFP8BlockScaling(), - scale inverse tensors offloading doest not work yet fp8_recipes.append(recipe.Float8CurrentScaling())
recipe.Float8CurrentScaling(), fp8_recipes.append(recipe.DelayedScaling())
recipe.DelayedScaling(),
]
SIZE = 512 model_config = {
NUM_HEADS = 8 "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
NUM_LAYERS = 5 }
EPSILON = 0.1 SIZE = model_config["small"].hidden_size
NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps
# Flash attention saves some internal tensor for the backward pass # Flash attention saves some internal tensor for the backward pass
# that cannot be offloaded to CPU. # that cannot be offloaded to CPU.
...@@ -124,11 +126,17 @@ def test_cpu_offload(fp8_recipe, model_key) -> None: ...@@ -124,11 +126,17 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
model_cls = model_types[model_key] model_cls = model_types[model_key]
models_list = [model_cls() for _ in range(NUM_LAYERS)] models_list = [model_cls() for _ in range(NUM_LAYERS)]
if fp8_recipe and not fp8_available: if model_key in ["multihead_attention", "transformer_layer"]:
pytest.skip(reason_for_no_fp8) available_backends, *_ = get_available_attention_backends(
if fp8_recipe is not None: model_config["small"],
if fp8_recipe.mxfp8() and not mxfp8_available: qkv_dtype=torch.bfloat16,
pytest.skip(reason_for_no_mxfp8) qkv_layout="sbhd_sbhd_sbhd",
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("Fused attention backend not available.")
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
without_offloading = _measure_memory_between_forward_and_backward( without_offloading = _measure_memory_between_forward_and_backward(
models_list, fp8_recipe, False models_list, fp8_recipe, False
......
...@@ -2,9 +2,7 @@ ...@@ -2,9 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
from dataclasses import dataclass from typing import Iterable, List, Union
import itertools
from typing import Iterable, List, Tuple, Union
import pytest import pytest
import torch import torch
...@@ -23,46 +21,32 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager ...@@ -23,46 +21,32 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe from transformer_engine.common import recipe
from utils import ModelConfig, reset_rng_states
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION: if IS_HIP_EXTENSION:
import os import os
from functools import cache from functools import cache
# Check if FP8 is supported. # Check if FP8 is supported.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
FP8GlobalStateManager.is_fp8_block_scaling_available() mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() # Reset RNG states.
reset_rng_states()
# Record initial RNG state. model_configs = {
seed = 1234 "small": ModelConfig(32, 2, 2, 32),
torch.manual_seed(seed) }
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state() fp8_recipes = []
_cuda_rng_state = torch.cuda.get_rng_state() if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_block_scaling_available:
@dataclass fp8_recipes.append(recipe.Float8BlockScaling())
class ModelConfig: if fp8_available:
"""Data tensor dimensions within Transformer model""" fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
sequence_length: int
batch_size: int
hidden_size: int
num_heads: int
kv_channels: int
model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
fp8_recipes = [
recipe.DelayedScaling(),
recipe.MXFP8BlockScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]
# Supported data types # Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16] dtypes: List[torch.dtype] = [torch.float32, torch.float16]
...@@ -70,12 +54,6 @@ if is_bf16_compatible(): # bf16 requires sm_80 or higher ...@@ -70,12 +54,6 @@ if is_bf16_compatible(): # bf16 requires sm_80 or higher
dtypes.append(torch.bfloat16) dtypes.append(torch.bfloat16)
def reset_rng_states() -> None:
"""Revert to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_global_fp8_state(): def reset_global_fp8_state():
yield yield
...@@ -119,7 +97,7 @@ def generate_data( ...@@ -119,7 +97,7 @@ def generate_data(
"""Generate synthetic data.""" """Generate synthetic data."""
gen_func = torch.ones if warmup else torch.randn gen_func = torch.ones if warmup else torch.randn
return gen_func( return gen_func(
model_config.sequence_length, model_config.max_seqlen_q,
model_config.batch_size, model_config.batch_size,
model_config.hidden_size, model_config.hidden_size,
device="cuda", device="cuda",
...@@ -157,10 +135,12 @@ class _Sequential(torch.nn.Sequential): ...@@ -157,10 +135,12 @@ class _Sequential(torch.nn.Sequential):
# Supported modules # Supported modules
_test_cuda_graphs_modules: List[str] = [ _test_cuda_graphs_modules: List[str] = [
# Put linear first to test the case where the cuda context might not be set in
# creating TMA descriptor for MXFP8 quantization.
"linear",
"transformer", "transformer",
"layernorm_mlp", "layernorm_mlp",
"layernorm_linear", "layernorm_linear",
"linear",
"mha", "mha",
"linear_op", "linear_op",
] ]
...@@ -310,35 +290,27 @@ def _test_cuda_graphs( ...@@ -310,35 +290,27 @@ def _test_cuda_graphs(
@pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_params", (False, True)) @pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
def test_make_graphed_callables( def test_make_graphed_callables(
*, *,
module: str, module: str,
model_config: str = "small", model_config: str = "small",
num_layers: int = 3, num_layers: int = 3,
dtype: torch.dtype, dtype: torch.dtype,
fp8: bool,
fp8_params: bool, fp8_params: bool,
fp8_recipe: recipe.Recipe, fp8_recipe: recipe.Recipe,
fp8_weight_caching: bool = False, fp8_weight_caching: bool = False,
) -> None: ) -> None:
# Skip invalid configurations. fp8 = fp8_recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_params and not fp8: if fp8_params and not fp8:
pytest.skip("FP8 needed for FP8 parameters.") pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8: if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.") pytest.skip("FP8 needed for FP8 parameters.")
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and module == "linear_op":
pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs") pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
# Run model with different CUDA graph settings. # Run model with different CUDA graph settings.
model_config = model_configs[model_config] model_config = model_configs[model_config]
kwargs = dict( kwargs = dict(
...@@ -351,9 +323,11 @@ def test_make_graphed_callables( ...@@ -351,9 +323,11 @@ def test_make_graphed_callables(
fp8_weight_caching=fp8_weight_caching, fp8_weight_caching=fp8_weight_caching,
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
) )
outputs = _test_cuda_graphs(graph_mode="none", **kwargs) # Put graphed callables first to test the case where the cuda context might not be set in
# creating TMA descriptor for MXFP8 quantization.
graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs) graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs) graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
# Check that results match. # Check that results match.
assert_all_equal(outputs, graph_outputs_mode1) assert_all_equal(outputs, graph_outputs_mode1)
...@@ -369,7 +343,6 @@ _test_make_graphed_callables_with_fp8_weight_caching_modules = [ ...@@ -369,7 +343,6 @@ _test_make_graphed_callables_with_fp8_weight_caching_modules = [
] ]
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"module", "module",
_test_make_graphed_callables_with_fp8_weight_caching_modules, _test_make_graphed_callables_with_fp8_weight_caching_modules,
...@@ -385,7 +358,6 @@ def test_make_graphed_callables_with_fp8_weight_caching( ...@@ -385,7 +358,6 @@ def test_make_graphed_callables_with_fp8_weight_caching(
test_make_graphed_callables( test_make_graphed_callables(
module=module, module=module,
dtype=torch.float32, dtype=torch.float32,
fp8=True,
fp8_params=fp8_params, fp8_params=fp8_params,
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
fp8_weight_caching=True, fp8_weight_caching=True,
...@@ -401,7 +373,7 @@ def generate_data_for_dot_product_attention( ...@@ -401,7 +373,7 @@ def generate_data_for_dot_product_attention(
gen_func = torch.ones if warmup else torch.randn gen_func = torch.ones if warmup else torch.randn
return [ return [
gen_func( gen_func(
model_config.sequence_length, model_config.max_seqlen_q,
model_config.batch_size, model_config.batch_size,
model_config.num_heads, model_config.num_heads,
model_config.kv_channels, model_config.kv_channels,
...@@ -495,8 +467,8 @@ def _test_cuda_graphs_with_kwargs( ...@@ -495,8 +467,8 @@ def _test_cuda_graphs_with_kwargs(
( (
model_config.batch_size, model_config.batch_size,
1, 1,
model_config.sequence_length, model_config.max_seqlen_q,
model_config.sequence_length, model_config.max_seqlen_kv,
), ),
dtype=torch.bool, dtype=torch.bool,
device="cuda", device="cuda",
...@@ -522,8 +494,8 @@ def _test_cuda_graphs_with_kwargs( ...@@ -522,8 +494,8 @@ def _test_cuda_graphs_with_kwargs(
( (
model_config.batch_size, model_config.batch_size,
1, 1,
model_config.sequence_length, model_config.max_seqlen_q,
model_config.sequence_length, model_config.max_seqlen_kv,
), ),
dtype=torch.bool, dtype=torch.bool,
device="cuda", device="cuda",
......
...@@ -223,7 +223,7 @@ class TestFloat8BlockwiseTensor: ...@@ -223,7 +223,7 @@ class TestFloat8BlockwiseTensor:
rowwise=True, rowwise=True,
columnwise=dq_columnwise, columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim, block_scaling_dim=block_scaling_dim,
all_gather_usage=True, all_gather_usage=(block_scaling_dim == 1),
) )
self._test_quantize_dequantize( self._test_quantize_dequantize(
quantizer=quantizer, quantizer=quantizer,
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
from itertools import product
import copy import copy
from contextlib import nullcontext from contextlib import nullcontext
...@@ -112,13 +111,6 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -112,13 +111,6 @@ class TestFusedAdam(TestFusedOptimizer):
def test_bfloat16(self): def test_bfloat16(self):
self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True) self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
def test_multi_params(self): def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]] sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
...@@ -530,13 +522,6 @@ class TestFusedSGD(TestFusedOptimizer): ...@@ -530,13 +522,6 @@ class TestFusedSGD(TestFusedOptimizer):
def test_half(self): def test_half(self):
self.gen_single_type_test(param_type=torch.float16) self.gen_single_type_test(param_type=torch.float16)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
class Model(torch.nn.Module): class Model(torch.nn.Module):
def __init__(self): def __init__(self):
......
...@@ -2,8 +2,7 @@ ...@@ -2,8 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
import torch import torch
import math from typing import Optional
from typing import Optional, Dict
from transformer_engine.pytorch.router import ( from transformer_engine.pytorch.router import (
fused_topk_with_score_function, fused_topk_with_score_function,
fused_compute_score_for_moe_aux_loss, fused_compute_score_for_moe_aux_loss,
...@@ -149,11 +148,21 @@ def run_comparison( ...@@ -149,11 +148,21 @@ def run_comparison(
# Set some parameters # Set some parameters
if score_function == "sigmoid": if score_function == "sigmoid":
# Construct the special logits to avoid inf in the sigmoid function # Construct the special logits to avoid inf in the sigmoid function
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 logits = (
torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
)
logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
else: else:
logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4 logits = (
torch.arange(
-num_tokens * num_experts // 2,
num_tokens * num_experts // 2,
device="cuda",
dtype=dtype,
)
* 1e-4
)
logits = logits.view(num_tokens, num_experts) logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True logits.requires_grad = True
if enable_bias and score_function == "sigmoid": if enable_bias and score_function == "sigmoid":
...@@ -282,11 +291,21 @@ def test_topk_softmax( ...@@ -282,11 +291,21 @@ def test_topk_softmax(
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
if score_function == "sigmoid": if score_function == "sigmoid":
# Construct the special logits to avoid inf in the sigmoid function # Construct the special logits to avoid inf in the sigmoid function
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 logits = (
torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
)
logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
else: else:
logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4 logits = (
torch.arange(
-num_tokens * num_experts // 2,
num_tokens * num_experts // 2,
device="cuda",
dtype=dtype,
)
* 1e-4
)
logits = logits.view(num_tokens, num_experts) logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True logits.requires_grad = True
...@@ -322,8 +341,8 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f ...@@ -322,8 +341,8 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f
@pytest.mark.parametrize("topk", [4]) @pytest.mark.parametrize("topk", [4])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
# Construct the special probs to avoid inf in the sigmoid function # Construct the special probs to avoid inf in the sigmoid function
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
probs = probs.view(num_tokens, num_experts) probs = probs.view(num_tokens, num_experts)
probs.requires_grad = True probs.requires_grad = True
...@@ -380,15 +399,12 @@ def profile_topk_softmax( ...@@ -380,15 +399,12 @@ def profile_topk_softmax(
if __name__ == "__main__": if __name__ == "__main__":
test_fused_scores_for_aux_loss( test_topk_softmax(
dtype=torch.float32, num_tokens=2, num_experts=32, topk=8, score_function="softmax" dtype=torch.float32,
num_tokens=1024,
num_experts=128,
topk=4,
use_pre_softmax=False,
group_topk=None,
scaling_factor=None,
) )
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4)
...@@ -21,10 +21,12 @@ import transformer_engine.pytorch as te ...@@ -21,10 +21,12 @@ import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import ( from transformer_engine.pytorch.ops.fused import (
BackwardBiasActivation, BackwardActivationBias,
BackwardLinearAdd, BackwardLinearAdd,
BackwardLinearScale,
ForwardLinearBiasActivation, ForwardLinearBiasActivation,
ForwardLinearBiasAdd, ForwardLinearBiasAdd,
ForwardLinearScaleAdd,
) )
from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import ( from transformer_engine.pytorch.tensor.float8_tensor import (
...@@ -39,7 +41,7 @@ import transformer_engine_torch as tex ...@@ -39,7 +41,7 @@ import transformer_engine_torch as tex
# Import utility functions # Import utility functions
_current_file = pathlib.Path(__file__).resolve() _current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent)) sys.path.append(str(_current_file.parent))
from utils import dtype_tols, make_recipe from utils import dtype_tols, make_recipe, reset_rng_states
if IS_HIP_EXTENSION: if IS_HIP_EXTENSION:
import os import os
...@@ -271,16 +273,72 @@ class TestSequentialContainer: ...@@ -271,16 +273,72 @@ class TestSequentialContainer:
model(torch.zeros(1)) model(torch.zeros(1))
assert len(model._module_groups) == 6 assert len(model._module_groups) == 6
def test_extra_tensors(self, size: int = 16) -> None:
"""Check that extra inputs are distributed properly between module groups
and that extra outputs are properly collected"""
# Construct sequential container
bias = te_ops.Bias(size=size, device="cpu")
with torch.no_grad():
bias.bias.copy_(torch.rand((size,)))
model = te_ops.Sequential( # | Inputs | Outputs
torch.nn.Identity(), # | x1 | x1
te_ops.MakeExtraOutput(in_place=True), # | x1 | x1 [x1]
bias, # | x1 | h1 (= x1 + b)
te_ops.MakeExtraOutput(in_place=True), # | h1 | h1 [h1]
te_ops.AddExtraInput(in_place=True), # | h1 [x2] | x2 (= x2 + h1)
te_ops.MakeExtraOutput(in_place=True), # | x2 | x2 [x2]
torch.nn.Identity(), # | x2 | x2
bias, # | x2 | h2 (= x2 + b)
te_ops.AddExtraInput(in_place=True), # | h2 [x3] | x3 (= x3 + h2)
te_ops.MakeExtraOutput(in_place=True), # | x3 | x3 [x3]
te_ops.AddExtraInput(in_place=True), # | x3 [x4] | x4 (= x4 + x3)
torch.nn.Identity(), # | x4 | x4
te_ops.Identity(), # | x4 | x4
te_ops.MakeExtraOutput(in_place=True), # | x4 | x4 [x4]
te_ops.Identity(), # | x4 | x4
)
# Create input tensors
x1 = torch.rand((size,))
x2 = torch.rand((size,))
x3 = torch.rand((size,))
x4 = torch.rand((size,))
# Save original input tensor values
x1_orig = x1.clone()
x2_orig = x2.clone()
x3_orig = x3.clone()
x4_orig = x4.clone()
# Run forward
ys = model(x1, x2, x3, x4)
# Check whether outputs match (x4, x1, h1, x2, x3, x4)
assert len(ys) == 6
assert ys[0].data_ptr() == x4.data_ptr()
assert ys[1].data_ptr() == x1.data_ptr()
assert ys[2].data_ptr() not in [x.data_ptr() for x in (x1, x2, x3, x4)]
assert ys[3].data_ptr() == x2.data_ptr()
assert ys[4].data_ptr() == x3.data_ptr()
assert ys[5].data_ptr() == x4.data_ptr()
# Check whether tensors have correct values
b = bias.bias
h1 = ys[2]
torch.testing.assert_close(x1, x1_orig)
torch.testing.assert_close(h1, x1_orig + b)
torch.testing.assert_close(x2, x2_orig + h1)
torch.testing.assert_close(x3, x3_orig + x2 + b)
torch.testing.assert_close(x4, x4_orig + x3)
class TestFuser: class TestFuser:
"""Tests for operation fusion infrastructure""" """Tests for operation fusion infrastructure"""
@staticmethod @staticmethod
def setup_class(cls) -> None: def setup_class(cls) -> None:
# Configure RNG reset_rng_states()
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_scale_update( def test_fp8_scale_update(
...@@ -494,10 +552,7 @@ class TestBasicOps: ...@@ -494,10 +552,7 @@ class TestBasicOps:
@staticmethod @staticmethod
def setup_class(cls) -> None: def setup_class(cls) -> None:
# Configure RNG reset_rng_states()
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu")) @pytest.mark.parametrize("device", ("cuda", "cpu"))
...@@ -795,10 +850,9 @@ class TestBasicOps: ...@@ -795,10 +850,9 @@ class TestBasicOps:
pytest.skip("FP8 output is only supported with FP8 GEMMs") pytest.skip("FP8 output is only supported with FP8 GEMMs")
if quantized_grad_input and not quantized_compute: if quantized_grad_input and not quantized_compute:
pytest.skip("FP8 grad input is only supported with FP8 GEMMs") pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
if quantization == "mxfp8" and quantized_output: if quantization not in (None, "fp8"):
pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs") if quantized_output or quantized_grad_input:
if quantization == "mxfp8" and quantized_grad_input: pytest.skip("Recipe does not support quantized GEMM output")
pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")
if ( IS_HIP_EXTENSION and not use_hipblaslt() and if ( IS_HIP_EXTENSION and not use_hipblaslt() and
accumulate_into_main_grad and dtype != torch.float32 and not quantized_compute): accumulate_into_main_grad and dtype != torch.float32 and not quantized_compute):
pytest.skip("Parameters combination is not supported by ROCBLAS") pytest.skip("Parameters combination is not supported by ROCBLAS")
...@@ -1353,18 +1407,17 @@ class TestBasicOps: ...@@ -1353,18 +1407,17 @@ class TestBasicOps:
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(y_test, y_ref, **tols)
# L2Norm backward pass requires slightly looser atol for bfloat16
if dtype == torch.bfloat16:
tols["atol"] = 2e-3
torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("in_place", (True, False))
@pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu")) @pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantization", _quantization_list)
def test_add_in_place( def test_add_extra_input(
self, self,
*, *,
in_shape: Iterable[int] = (32, 32), in_shape: Iterable[int] = (32, 32),
in_place: bool,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
quantization: Optional[str], quantization: Optional[str],
...@@ -1410,7 +1463,7 @@ class TestBasicOps: ...@@ -1410,7 +1463,7 @@ class TestBasicOps:
dx2_ref = dy_ref dx2_ref = dy_ref
# Implementation with fusible operation # Implementation with fusible operation
op = te_ops.AddInPlace() op = te_ops.AddExtraInput(in_place=in_place)
y_test = op(x1_test, x2_test) y_test = op(x1_test, x2_test)
y_test.backward(dy_test) y_test.backward(dy_test)
...@@ -1425,6 +1478,7 @@ class TestBasicOps: ...@@ -1425,6 +1478,7 @@ class TestBasicOps:
torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0) torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0)
torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0) torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0)
@pytest.mark.parametrize("in_place", (True, False))
@pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu")) @pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantization", _quantization_list)
...@@ -1432,6 +1486,7 @@ class TestBasicOps: ...@@ -1432,6 +1486,7 @@ class TestBasicOps:
self, self,
*, *,
in_shape: Iterable[int] = (32, 32), in_shape: Iterable[int] = (32, 32),
in_place: bool,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
quantization: Optional[str], quantization: Optional[str],
...@@ -1477,7 +1532,7 @@ class TestBasicOps: ...@@ -1477,7 +1532,7 @@ class TestBasicOps:
(y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward() (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()
# Implementation with fusible operation # Implementation with fusible operation
op = te_ops.MakeExtraOutput() op = te_ops.MakeExtraOutput(in_place=in_place)
y1_test, y2_test = op(x_test) y1_test, y2_test = op(x_test)
(y1_test * dy1_test + y2_test * dy2_test).sum().backward() (y1_test * dy1_test + y2_test * dy2_test).sum().backward()
...@@ -1645,16 +1700,107 @@ class TestBasicOps: ...@@ -1645,16 +1700,107 @@ class TestBasicOps:
torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", _devices)
def test_constant_scale(
    self,
    *,
    scale: float,
    shape: Iterable[int],
    dtype: torch.dtype,
    device: torch.device,
):
    """ConstantScale op matches plain PyTorch scaling in forward and backward."""

    # Random input and upstream gradient
    x_ref, x_test = make_reference_and_test_tensors(
        shape,
        test_dtype=dtype,
        test_device=device,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        shape,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    # Reference: plain tensor-scalar multiply
    y_ref = x_ref * scale
    y_ref.backward(dy_ref)

    # Fusible-operation implementation
    scale_op = te_ops.ConstantScale(scale)
    y_test = scale_op(x_test)
    y_test.backward(dy_test)

    # Compare results within dtype tolerances (in fp64 on CPU)
    tols = dtype_tols(dtype)
    y_out = y_test.to(dtype=torch.float64, device="cpu")
    dx_out = x_test.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_out, y_ref, **tols)
    torch.testing.assert_close(dx_out, x_ref.grad, **tols)
@pytest.mark.parametrize("prob", (0.1, 0.5, 0.75))
@pytest.mark.parametrize("is_training", (True, False))
@pytest.mark.parametrize("shape", ((101,), (2, 4, 16)))
@pytest.mark.parametrize("dtype", _dtypes)
def test_dropout(
    self,
    *,
    prob: float,
    is_training: bool,
    shape: Iterable[int],
    dtype: torch.dtype,
    device: torch.device = "cuda",
):
    """Check Dropout op values and its zero-fraction distribution.

    In training mode the surviving entries must equal input / (1 - prob)
    and gradients must use the same mask; in eval mode the op must be an
    exact identity. A z-test on the observed zero fraction checks that
    the drop probability is correct.
    """
    # Random data
    # Inputs are offset into [0.5, 1.5) so no input value is zero and
    # a zero in the output unambiguously marks a dropped element.
    x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
    x_test = x_ref.clone().requires_grad_()
    dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
    dy_test = dy_ref.clone()

    # Apply dropout
    op = te_ops.Dropout(prob)
    if is_training:
        op.train()
    else:
        op.eval()
    y = op(x_test)
    y.backward(dy_test)

    # Check values
    if is_training:
        # Reconstruct the keep mask (with 1/(1-prob) scaling) from the
        # output's zero pattern and verify forward and grad agree on it.
        mask = ((y != 0) / (1 - prob)).to(dtype=dtype)
        torch.testing.assert_close(y, x_ref * mask)
        torch.testing.assert_close(x_test.grad, dy_ref * mask)
    else:
        # Eval mode must be an exact pass-through
        torch.testing.assert_close(y, x_ref, rtol=0, atol=0)
        torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0)

    # Hypothesis testing for number of zeros
    # Note: A Bernoulli random variable with probability p has
    # mean p and standard deviation sqrt(p*(1-p)). By the central
    # limit theorem, the mean of n iid Bernoulli variables
    # converges to a normal random variable with mean p and
    # standard deviation sqrt(p*(1-p)/n). If the observed mean is
    # below the 0.5th or above the 99.5th percentiles, then the
    # p-value is less than 1% and we assume that the dropout
    # distribution is incorrect.
    if is_training:
        prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
        z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
        assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval"
class TestFusedOps: class TestFusedOps:
"""Tests for fused operations""" """Tests for fused operations"""
@staticmethod @staticmethod
def setup_class(cls) -> None: def setup_class(cls) -> None:
# Configure RNG reset_rng_states()
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5))) @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1))) @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
...@@ -1841,7 +1987,7 @@ class TestFusedOps: ...@@ -1841,7 +1987,7 @@ class TestFusedOps:
device=device, device=device,
dtype=dtype, dtype=dtype,
), ),
te_ops.AddInPlace(), te_ops.AddExtraInput(in_place=True),
) )
with torch.no_grad(): with torch.no_grad():
model[0].weight.copy_(w_test) model[0].weight.copy_(w_test)
...@@ -1878,11 +2024,114 @@ class TestFusedOps: ...@@ -1878,11 +2024,114 @@ class TestFusedOps:
db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu") db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_forward_linear_scale_add(
    self,
    *,
    scale: float,
    weight_shape: tuple[int, int] = (32, 32),
    in_shape: Iterable[int] = (32, -1),
    dtype: torch.dtype,
    device: torch.device = "cuda",
    quantization: Optional[str],
    quantized_weight: bool = False,
) -> None:
    """Forward GEMM + scale + add

    Checks that Linear -> ConstantScale -> AddExtraInput fuses into a
    single ForwardLinearScaleAdd forward op and that outputs and all
    gradients match a plain PyTorch reference.
    """

    # Make input and weight shapes consistent
    # (-1 in in_shape is replaced by the weight's in_features)
    out_features, in_features = weight_shape
    in_shape = list(in_shape)[:-1] + [in_features]
    out_shape = in_shape[:-1] + [out_features]

    # Skip invalid configurations
    quantized_compute = quantization is not None
    maybe_skip_quantization(quantization, dims=in_shape, device=device)
    maybe_skip_quantization(quantization, dims=out_shape)
    if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
        pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")

    # Random data
    x1_ref, x1_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    w_ref, w_test = make_reference_and_test_tensors(
        (out_features, in_features),
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    # x2 is the extra input added after the scaled GEMM; it is not
    # quantized since the add consumes it directly.
    x2_ref, x2_test = make_reference_and_test_tensors(
        out_shape,
        test_dtype=dtype,
        test_device=device,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    # Plain PyTorch implementation
    y_ref = torch.nn.functional.linear(x1_ref, w_ref) * scale + x2_ref
    y_ref.backward(dy_ref)

    # Implementation with fusible operations
    recipe = make_recipe(quantization)
    with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
        model = te_ops.Sequential(
            te_ops.Linear(
                in_features,
                out_features,
                bias=False,
                device=device,
                dtype=dtype,
            ),
            te_ops.ConstantScale(scale),
            te_ops.AddExtraInput(in_place=True),
            te_ops.Quantize(),
        )
    with torch.no_grad():
        model[0].weight.copy_(w_test)
        del w_test
    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
        y_test = model(x1_test, x2_test)
    y_test.backward(dy_test)

    # Check that forward operations have been fused
    # (Linear + ConstantScale + AddExtraInput collapse into one op;
    # the trailing Quantize stays separate.)
    forward_ops = model._module_groups[0]._forward_ops
    assert len(forward_ops) == 2
    assert isinstance(forward_ops[0][0], ForwardLinearScaleAdd)
    assert isinstance(forward_ops[1][0], te_ops.Quantize)

    # Expected numerical error
    tols = dtype_tols(dtype)
    if dtype == torch.float32:
        tols = dtype_tols(torch.float16)  # TF32 GEMM
    if quantized_compute:
        tols = dtype_tols(tex.DType.kFloat8E4M3)

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
    dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
    dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **tols)
    torch.testing.assert_close(dx1_test, x1_ref.grad, **tols)
    torch.testing.assert_close(dx2_test, x2_ref.grad, **tols)
    torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@pytest.mark.parametrize("activation", ("relu", "gelu")) @pytest.mark.parametrize("activation", ("relu", "gelu"))
@pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32))) @pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32)))
@pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_bias_activation( def test_backward_activation_bias(
self, self,
*, *,
activation: str, activation: str,
...@@ -1891,7 +2140,7 @@ class TestFusedOps: ...@@ -1891,7 +2140,7 @@ class TestFusedOps:
device: torch.device = "cuda", device: torch.device = "cuda",
quantization: Optional[str], quantization: Optional[str],
) -> None: ) -> None:
"""Backward dbias + dact + quantize""" """Backward dact + dbias + quantize"""
# Tensor dimensions # Tensor dimensions
in_shape = list(out_shape) in_shape = list(out_shape)
...@@ -1948,9 +2197,9 @@ class TestFusedOps: ...@@ -1948,9 +2197,9 @@ class TestFusedOps:
# Check that backward operations have been fused # Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops backward_ops = model._module_groups[0]._backward_ops
if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]: if with_quantization:
assert len(backward_ops) == 2 assert len(backward_ops) == 2
assert isinstance(backward_ops[0][0], BackwardBiasActivation) assert isinstance(backward_ops[0][0], BackwardActivationBias)
assert isinstance(backward_ops[1][0], te_ops.Quantize) assert isinstance(backward_ops[1][0], te_ops.Quantize)
else: else:
assert len(backward_ops) == 3 assert len(backward_ops) == 3
...@@ -1963,6 +2212,7 @@ class TestFusedOps: ...@@ -1963,6 +2212,7 @@ class TestFusedOps:
if with_quantization: if with_quantization:
tols = dtype_tols(tex.DType.kFloat8E4M3) tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu") y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu") db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
...@@ -2033,7 +2283,7 @@ class TestFusedOps: ...@@ -2033,7 +2283,7 @@ class TestFusedOps:
recipe = make_recipe(quantization) recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight): with te.fp8_model_init(enabled=quantized_weight):
model = te_ops.Sequential( model = te_ops.Sequential(
te_ops.MakeExtraOutput(), te_ops.MakeExtraOutput(in_place=True),
te_ops.Linear( te_ops.Linear(
in_features, in_features,
out_features, out_features,
...@@ -2071,16 +2321,106 @@ class TestFusedOps: ...@@ -2071,16 +2321,106 @@ class TestFusedOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols) torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_linear_scale(
    self,
    *,
    scale: float,
    weight_shape: tuple[int, int] = (32, 32),
    in_shape: Iterable[int] = (32, -1),
    dtype: torch.dtype,
    device: torch.device = "cuda",
    quantization: Optional[str],
    quantized_weight: bool = False,
) -> None:
    """Backward dgrad GEMM + scale

    Checks that Linear -> ConstantScale fuses its backward pass into a
    single BackwardLinearScale op and that output and gradients match a
    plain PyTorch reference.
    """

    # Make input and weight shapes consistent
    # (-1 in in_shape is replaced by the weight's in_features)
    out_features, in_features = weight_shape
    in_shape = list(in_shape)[:-1] + [in_features]
    out_shape = in_shape[:-1] + [out_features]

    # Skip invalid configurations
    quantized_compute = quantization is not None
    maybe_skip_quantization(quantization, dims=in_shape, device=device)
    maybe_skip_quantization(quantization, dims=out_shape)
    if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
        pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")

    # Random data
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    w_ref, w_test = make_reference_and_test_tensors(
        (out_features, in_features),
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        quantization=quantization,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    # Plain PyTorch implementation
    y_ref = torch.nn.functional.linear(x_ref, w_ref) * scale
    y_ref.backward(dy_ref)

    # Implementation with fusible operations
    recipe = make_recipe(quantization)
    # Pass the recipe explicitly so quantized-weight initialization uses
    # the same recipe as the autocast below (consistent with
    # test_forward_linear_scale_add).
    with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
        model = te_ops.Sequential(
            te_ops.Linear(
                in_features,
                out_features,
                bias=False,
                device=device,
                dtype=dtype,
            ),
            te_ops.ConstantScale(scale),
        )
    with torch.no_grad():
        model[0].weight.copy_(w_test)
        del w_test
    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
        y_test = model(x_test)
    # Equivalent to y_test.backward(dy_test), expressed as a scalar loss
    (y_test * dy_test).sum().backward()

    # Check that backward operations have been fused
    backward_ops = model._module_groups[0]._backward_ops
    assert len(backward_ops) == 1
    assert isinstance(backward_ops[0][0], BackwardLinearScale)

    # Expected numerical error
    tols = dtype_tols(dtype)
    if dtype == torch.float32:
        tols = dtype_tols(torch.float16)  # TF32 GEMM
    if quantized_compute:
        tols = dtype_tols(tex.DType.kFloat8E4M3)

    # Check results
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **tols)
    torch.testing.assert_close(dx_test, x_ref.grad, **tols)
    torch.testing.assert_close(dw_test, w_ref.grad, **tols)
class TestCheckpointing: class TestCheckpointing:
"""Tests for checkpointing""" """Tests for checkpointing"""
@staticmethod @staticmethod
def setup_class(cls) -> None: def setup_class(cls) -> None:
# Configure RNG reset_rng_states()
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_weight", (False, True)) @pytest.mark.parametrize("quantized_weight", (False, True))
...@@ -2192,11 +2532,9 @@ class TestSequentialModules: ...@@ -2192,11 +2532,9 @@ class TestSequentialModules:
@staticmethod @staticmethod
def setup_class(cls) -> None: def setup_class(cls) -> None:
# Configure RNG reset_rng_states()
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("requires_grad", (False, True))
@pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm")) @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
@pytest.mark.parametrize("quantized_compute", (False, True)) @pytest.mark.parametrize("quantized_compute", (False, True))
...@@ -2206,6 +2544,7 @@ class TestSequentialModules: ...@@ -2206,6 +2544,7 @@ class TestSequentialModules:
def test_layernorm_mlp( def test_layernorm_mlp(
self, self,
*, *,
requires_grad: bool,
bias: bool, bias: bool,
normalization: str, normalization: str,
quantized_compute: bool, quantized_compute: bool,
...@@ -2246,6 +2585,7 @@ class TestSequentialModules: ...@@ -2246,6 +2585,7 @@ class TestSequentialModules:
quantization=quantization, quantization=quantization,
test_dtype=dtype, test_dtype=dtype,
test_device=device, test_device=device,
requires_grad=requires_grad,
) )
_, dy_test = make_reference_and_test_tensors( _, dy_test = make_reference_and_test_tensors(
in_shape, in_shape,
......
...@@ -7,7 +7,6 @@ from transformers.configuration_utils import PretrainedConfig ...@@ -7,7 +7,6 @@ from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformer_engine.pytorch.transformer import TransformerLayer from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.utils import is_bf16_compatible
class SimpleTEModel(PreTrainedModel): class SimpleTEModel(PreTrainedModel):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment