Commit 87e3e56e authored by yuguo

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
......@@ -1604,16 +1604,18 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
@contextmanager
def use_jax_gemm(enabled=False):
orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None)
orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS", None)
try:
if enabled:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false"
else:
os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=true"
yield
finally:
if enabled:
if orig_custom_calls_filter is None:
os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
os.environ.pop("NVTE_JAX_CUSTOM_CALLS")
else:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter
os.environ["NVTE_JAX_CUSTOM_CALLS"] = orig_custom_calls_filter
......@@ -13,7 +13,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel
get_cu_seqlens_on_cp_rank,
)
import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.common.recipe import DelayedScaling
......
......@@ -4,12 +4,12 @@
import logging
import math
import os
from typing import Any, Dict, List, Tuple, Union, Optional
from contextlib import contextmanager
import sys
import pathlib
from typing import Any, Dict, Tuple, Union
import pytest
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
......@@ -20,11 +20,8 @@ from transformer_engine.pytorch.attention.dot_product_attention import (
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils,
get_attention_backend,
check_set_window_size,
AttentionParams,
)
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
......@@ -49,21 +46,21 @@ from transformer_engine.pytorch.tensor.quantized_tensor import (
restore_from_saved,
)
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
reset_rng_states,
ModelConfig,
dtype_tols,
get_available_attention_backends,
)
# Only run FP8 tests on H100
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
# Reset RNG states
reset_rng_states()
@pytest.fixture(autouse=True)
......@@ -72,170 +69,20 @@ def reset_global_fp8_state():
fp8.FP8GlobalStateManager.reset()
class ModelConfig:
def __init__(
self,
batch_size: int,
num_heads: int,
num_gqa_groups: int,
head_dim_qk: int,
max_seqlen_q: int,
max_seqlen_kv: int,
dropout_p: float,
attn_mask_type: str,
attn_bias_type: str,
head_dim_v: int = None,
alibi_type: str = "none",
num_layers: int = 1,
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
total_requests: int = None,
max_ctx_len: int = None,
):
self.batch_size = batch_size
self.num_heads = num_heads
self.num_gqa_groups = num_gqa_groups
self.head_dim_qk = head_dim_qk
self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
self.hidden_size = num_heads * head_dim_qk
self.hidden_size_kv = num_gqa_groups * self.head_dim_v
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_kv
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers
self.bias_shape = bias_shape
self.window_size = window_size
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
@contextmanager
def logging_context(highest_level=logging.WARNING):
previous_level = logging.root.manager.disable
logging.disable(highest_level)
try:
yield
finally:
logging.disable(previous_level)
def _get_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
qkv_layout: str,
window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
is_training: bool = True,
inference_params: Optional[InferenceParams] = None,
) -> Tuple[List, List]:
"""Check if what attention backends support a model configuration"""
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes_shape = [config.num_heads]
if config.bias_shape == "bhss":
alibi_slopes_shape = [config.batch_size, config.num_heads]
core_attention_bias_shape = (
config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
)
core_attention_bias_requires_grad = False
# d=256 is supported by cuDNN 9.0+ for inference but not training
if (
config.attn_bias_type == "post_scale_bias"
and config.head_dim_qk <= 128
and config.head_dim_v <= 128
):
core_attention_bias_requires_grad = True
fused_attn_backends = []
available_backends = None
flash_attention_backend = None
fused_attention_backend = None
def test():
attention_params = AttentionParams(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
)
(
use_flash_attention,
use_fused_attention,
flash_attention_backend,
fused_attention_backend,
use_unfused_attention,
available_backends,
) = get_attention_backend(attention_params)
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["flash_attention_backend"] = flash_attention_backend
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return available_backends, flash_attention_backend, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, flash_attention_backend, fused_attn_backends
model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_1_0": ModelConfig(8, 128, 16, 64),
"base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
"base_2_0": ModelConfig(2, 2048, 24, 128),
"base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
"base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
"base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
"base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
"base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
"base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
"base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
"base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
"base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
}
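Note: these entries use the refactored ModelConfig imported from the shared utils module, which takes batch_size, max_seqlen_q, num_heads, and head_dim_qk positionally and everything else as keywords; the defaults assumed below (no mask, no bias, zero dropout, num_gqa_groups equal to num_heads, max_seqlen_kv equal to max_seqlen_q) are inferred from this diff rather than from the class definition. A minimal sketch of the old-to-new mapping:
from utils import ModelConfig  # shared helper reached via the sys.path.append above
# Old positional form: ModelConfig(b, h, hg, d, sq, skv, p, mask, bias)
#   "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias")
# New form, as used above: ModelConfig(batch_size, max_seqlen_q, num_heads, head_dim_qk, ...)
self_attn_cfg = ModelConfig(8, 128, 16, 64)                      # "base_1_0"
cross_attn_cfg = ModelConfig(4, 128, 16, 64, max_seqlen_kv=256)  # "base_1_1"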
......@@ -279,7 +126,7 @@ def test_dot_product_attention(
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......@@ -290,7 +137,7 @@ def test_dot_product_attention(
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......@@ -411,62 +258,26 @@ def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
if IS_HIP_EXTENSION:
model_configs_mla = {
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig(
8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128
), # self , 0
"mla_1_1": ModelConfig(
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_1_2": ModelConfig(
4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_2_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
), # self , 1
"mla_2_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # cross, 1
"mla_2_2": ModelConfig(
1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128
), # cross, 1
}
else:
model_configs_mla = {
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig(
8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128
), # self , 0
"mla_1_1": ModelConfig(
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_1_2": ModelConfig(
4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_2_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
), # self , 1
"mla_2_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # cross, 1
"mla_2_2": ModelConfig(
1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128
), # cross, 1
"mla_3_0": ModelConfig(
8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64
), # inference
"mla_3_1": ModelConfig(
8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
"mla_3_2": ModelConfig(
8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
}
model_configs_mla = {
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0
"mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0
"mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0
"mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1
"mla_2_1": ModelConfig(
1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
), # cross, 1
"mla_2_2": ModelConfig(
1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
), # cross, 1
"mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
"mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
}
@pytest.mark.skipif(not IS_HIP_EXTENSION and get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys())
......@@ -477,40 +288,46 @@ def test_dpa_mla(dtype, model_configs, model):
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"mask_5_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
"mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
"mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
"mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
"mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
"mask_2_1": ModelConfig(
2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right"
),
"mask_2_2": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
),
"mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
"mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
"mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
"mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
"mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
"mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
"mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
"mask_5_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
),
"mask_5_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
),
"mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"),
"mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"),
"mask_7_0": ModelConfig(
2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
),
"mask_7_1": ModelConfig(
2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right"
),
"mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"),
"mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"),
"mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_8_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding", "no_bias"),
"mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"),
"mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"),
"mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"),
"mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"),
"mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"),
"mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
"mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
"mask_10_0": ModelConfig(
2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
),
"mask_10_1": ModelConfig(
2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
),
}
......@@ -526,44 +343,102 @@ def test_dpa_mask(dtype, model_configs, model):
model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
"bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
"bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"),
"bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped
"bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped
"bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
"bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
"bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
"bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), # skipped
"bias_1_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"
), # skipped
"bias_2_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
), # skipped
"bias_2_1": ModelConfig(
2,
128,
16,
64,
max_seqlen_kv=256,
attn_mask_type="padding",
attn_bias_type="post_scale_bias",
), # skipped
"bias_2_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"
4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
), # skipped
"bias_2_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"
2,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="padding",
attn_bias_type="post_scale_bias",
), # skipped
"bias_2_4": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"
), # skipped
"bias_2_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
), # skipped
"bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped
"bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped
"bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
"bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"bias_3_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"bias_3_1": ModelConfig(
2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"bias_3_2": ModelConfig(
4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"bias_3_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"
2,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
), # skipped
"bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
"bias_3_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
), # skipped
"bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
"bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
"bias_4_0": ModelConfig(
4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"
4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
), # skipped
"bias_4_1": ModelConfig(
2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"
2,
128,
16,
64,
max_seqlen_kv=256,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
), # skipped
"bias_4_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"
4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
), # skipped
"bias_4_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"
2,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
), # skipped
"bias_4_4": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
), # skipped
"bias_4_5": ModelConfig(
2,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="alibi",
), # skipped
"bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
"bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
}
......@@ -578,33 +453,29 @@ def test_dpa_bias(dtype, model_configs, model):
model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
"bias_1_0": ModelConfig(
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
"bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
"bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
"bias_1_4": ModelConfig(
4,
16,
16,
64,
128,
2048,
24,
128,
0.0,
# mask, bias, bias_shape,
"no_mask",
"post_scale_bias",
bias_shape="11ss",
),
"bias_1_1": ModelConfig(
2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="1hss"
),
"bias_1_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="b1ss"
),
"bias_1_3": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="bhss"
),
"bias_1_4": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="1hss", alibi_type="custom"
attn_mask_type="causal",
attn_bias_type="alibi",
bias_shape="1hss",
alibi_type="custom",
),
"bias_1_5": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="bhss", alibi_type="custom"
2,
2048,
24,
128,
attn_mask_type="causal",
attn_bias_type="alibi",
bias_shape="bhss",
alibi_type="custom",
),
}
......@@ -620,34 +491,36 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"swa_6_1": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
"swa_1_1": ModelConfig(2, 2048, 16, 64),
"swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
"swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
"swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
"swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"),
"swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
"swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"),
"swa_3_2": ModelConfig(
2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right"
),
"swa_3_3": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right"
),
"swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
"swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"),
"swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
"swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
"swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
"swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"),
"swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
"swa_6_2": ModelConfig(
2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right"
),
"swa_6_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
),
}
@pytest.mark.skipif((not IS_HIP_EXTENSION) and (not FlashAttentionUtils.v2_3_plus), reason="Flash-attn 2.3+ is required.")
@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
......@@ -658,18 +531,36 @@ def test_dpa_sliding_window(dtype, model_configs, model):
model_configs_alibi_slopes = {
# test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
"alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_1_0": ModelConfig(
2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
),
"alibi_1_1": ModelConfig(
1,
128,
16,
64,
max_seqlen_kv=256,
attn_mask_type="causal",
attn_bias_type="alibi",
alibi_type="vanilla",
),
"alibi_2_0": ModelConfig(
2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom"
2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom"
),
"alibi_2_1": ModelConfig(
1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom"
1,
1024,
24,
128,
max_seqlen_kv=2048,
attn_mask_type="causal",
attn_bias_type="alibi",
alibi_type="custom",
),
}
@pytest.mark.skipif((not IS_HIP_EXTENSION) and (not FlashAttentionUtils.v2_3_plus), reason="Flash-attn 2.3+ is required.")
@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
......@@ -694,16 +585,38 @@ qkv_layouts = [
model_configs_layout = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"),
"layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
"layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
"layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"layout_0_0": ModelConfig(2, 128, 16, 64),
"layout_0_1": ModelConfig(
2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
"layout_0_3": ModelConfig(
1,
128,
16,
64,
max_seqlen_kv=256,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
),
"layout_1_0": ModelConfig(2, 2048, 24, 128),
"layout_1_1": ModelConfig(
2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
"layout_1_3": ModelConfig(
1,
2048,
24,
128,
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
),
"layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048),
"layout_2_1": ModelConfig(
2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
}
......@@ -720,55 +633,54 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"layout_2_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
"layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
"layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
"layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
"layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"),
"layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"),
"layout_1_2": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"
),
"layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"),
"layout_2_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right"
),
"layout_2_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
"layout_3_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
),
"layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)),
"layout_3_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)
2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4)
),
"layout_3_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4)
),
"layout_4_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4)
),
"layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)),
"layout_4_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0)
),
"layout_4_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0)
),
"layout_5_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)
2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0)
),
"layout_5_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)
2,
2048,
24,
128,
num_gqa_groups=1,
attn_mask_type="padding_causal_bottom_right",
window_size=(4, 0),
),
"layout_5_2": ModelConfig(
2,
24,
2048,
24,
128,
2048,
4096,
0.0,
"padding_causal_bottom_right",
"no_bias",
max_seqlen_kv=4096,
attn_mask_type="padding_causal_bottom_right",
window_size=(4, 0),
),
}
......@@ -1158,16 +1070,22 @@ def _run_dot_product_attention(
model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
"te_1_3": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"te_2_3": ModelConfig(1, 16, 16, 64, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"),
"te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
"te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
"te_1_1": ModelConfig(
4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"te_1_2": ModelConfig(
2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
),
"te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
"te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"),
"te_2_1": ModelConfig(2, 2048, 16, 64),
"te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"),
"te_2_3": ModelConfig(
1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right"
),
"te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
"te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"),
}
......@@ -1189,26 +1107,27 @@ def test_transformer_layer(
tols = dict(atol=5e-2, rtol=5e-2)
workspace_opt = True
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd"
# Override the qkv_layout in MQA/GQA mode in ROCm TE
if IS_HIP_EXTENSION and model_configs[model].num_gqa_groups != model_configs[model].num_heads:
qkv_layout = "sbhd_sbhd_sbhd"
# Test backend availability
is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
qkv_layout=(
qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
qkv_layout=(
qkv_format.replace("hd", "h3d")
if fused_qkv_params
else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
......@@ -1514,20 +1433,164 @@ def _run_transformer_layer(
return out, inp.grad
model_configs_fp8_extra_state = {
"large": ModelConfig(2, 128, 4, 128, num_layers=1),
}
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
config = model_configs_fp8_extra_state[model]
# Test backend availability
is_training = True
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=torch.float8_e4m3fn,
qkv_layout="sb3hd",
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported and not flash_attn_supported:
pytest.skip("No attention backend available.")
outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
outputs_checkpoint_v1_6 = _run_attention_extra_state(
dtype, config, mimic_v1_6=True, checkpoint=True
)
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols.update(dict(rtol=2e-2, atol=2e-3))
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
torch.testing.assert_close(
test,
ref,
**tols,
)
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
torch.testing.assert_close(
test,
ref,
**tols,
)
def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
steps = 10
path = "checkpoint.pt"
fp8_enabled = True
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_enabled,
fp8_mha=False,
)
reset_rng_states()
hidden_states = torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
def get_model(dtype, config):
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.0,
attention_dropout=0.0,
fuse_qkv_params=True,
params_dtype=dtype,
device="cuda",
)
return block
block = get_model(dtype, config)
for i in range(steps // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
if checkpoint:
sd = block.state_dict()
if mimic_v1_6:
sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
"self_attention.core_attention._extra_state"
]
del sd["self_attention.core_attention._extra_state"]
torch.save(sd, path)
param_grads = []
for p in block.parameters():
if p.requires_grad:
param_grads.append(p.grad.clone())
_cpu_rng_state_new = torch.get_rng_state()
_cuda_rng_state_new = torch.cuda.get_rng_state()
del block
block = get_model(dtype, config)
block.load_state_dict(torch.load(path, weights_only=False))
torch.set_rng_state(_cpu_rng_state_new)
torch.cuda.set_rng_state(_cuda_rng_state_new)
for p in block.parameters():
if p.requires_grad:
p.grad = param_grads.pop(0)
assert not param_grads, "Oops!"
for i in range((steps + 1) // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
torch.cuda.synchronize()
if os.path.exists(path):
os.remove(path)
outputs = [output, hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
model_configs_fp8_vs_f16 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_9": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_10": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_11": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
"fp8_12": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_13": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
"fp8_15": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"fp8_16": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"fp8_17": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding", "no_bias"),
"fp8_18": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"fp8_19": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"fp8_20": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding_causal", "no_bias"),
"fp8_9": ModelConfig(2, 2048, 16, 128),
"fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
"fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
"fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"),
"fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"),
"fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
"fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"),
"fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"),
"fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"),
"fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"),
"fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"),
"fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"),
}
param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
......@@ -1561,7 +1624,7 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
)
)
@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
......@@ -1576,18 +1639,30 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]
if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (
9,
7,
0,
):
pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
# Test backend availability
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=torch.float8_e4m3fn,
qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
if not fp8_dpa_bwd:
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("No attention backend available.")
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
......@@ -1613,11 +1688,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
rtol = 5e-1
rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output"))
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
if flash_attn_supported:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
......@@ -1768,7 +1839,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
return out, param_names, tuple(x.grad for x in params)
return out, param_names, tuple(None for x in params)
@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
......@@ -1790,23 +1861,34 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
# if get_device_compute_capability() >= (10, 0):
# config.dropout_p = 0.1
if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (
9,
7,
0,
):
pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
# Test backend availability
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=torch.float8_e4m3fn,
qkv_layout=qkv_layout,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if flash_attn_supported + fused_attn_supported < 1:
pytest.skip("No FP8 attention backend available.")
if not fp8_dpa_bwd:
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("No attention backend available.")
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
......@@ -1835,11 +1917,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
rmse_tol = 0.11
bwd_names = ["dq", "dk", "dv"]
logging.debug("========== {:^25s} ==========".format("forward output"))
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
if flash_attn_supported:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
......@@ -2013,21 +2091,21 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
model_configs_fp8 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_1": ModelConfig(1, 512, 1, 64),
"fp8_2": ModelConfig(4, 512, 16, 64),
"fp8_3": ModelConfig(1, 2048, 1, 128),
"fp8_4": ModelConfig(2, 2048, 24, 128),
"fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"),
"fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"),
"fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"),
"fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(
(
get_cudnn_version() < (8, 9, 3)
......@@ -2049,6 +2127,18 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
config = model_configs_fp8[model]
# Test backend availability
is_training = True
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=torch.float8_e4m3fn,
qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not (fused_attn_backends and unfused_attn_supported):
pytest.skip("Not enough backends to run this test with.")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
......
......@@ -4,6 +4,8 @@
import os
import subprocess
import sys
import pathlib
import pytest
import torch
......@@ -12,27 +14,29 @@ from transformer_engine.pytorch.utils import (
get_cudnn_version,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
from test_fused_attn import ModelConfig
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig, get_available_attention_backends
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
from torch.utils.cpp_extension import IS_HIP_EXTENSION
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_1_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
"cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)), # MHA
"cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
"cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
"cp_2_2": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
"cp_2_3": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
), # GQA
"cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA
}
......@@ -44,7 +48,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
"--nproc-per-node=" + str(num_gpus_per_node),
]
te_path = os.getenv("TE_PATH", "/opt/transformerengine")
script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py")
args.append(script_path)
for k, v in kwargs.items():
args.append(f"{k}={v}")
......@@ -94,37 +98,41 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
"cp_1_4": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
"cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA
"cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
"cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
"cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
"cp_2_2": ModelConfig(
2,
4096,
12,
128,
num_gqa_groups=2,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
), # GQA
"cp_2_3": ModelConfig(
2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
), # GQA
"cp_2_4": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
), # GQA
"cp_3_0": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # MLA
"cp_3_1": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64
), # MLA
"cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA
"cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA
"cp_3_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0), reason="DTK not surpport fused attn for now, CP tests require sm80+.")
@pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
......@@ -176,6 +184,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("MLA CP currently only support KV P2P!")
if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently does not support FP8 attention!")
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtypes[dtype],
qkv_layout="_".join([qkv_format] * 3),
window_size=config.window_size,
context_parallel=True,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("No attention backend available.")
subprocess.run(
get_bash_arguments(
......
......@@ -5,18 +5,14 @@
from collections import OrderedDict
from typing import List
import os
import sys
import pathlib
import logging
import math
import pytest
import torch
from test_fused_attn import (
ModelConfig,
reset_rng_states,
_get_attention_backends,
)
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
......@@ -34,26 +30,25 @@ from transformer_engine.pytorch.utils import (
is_bf16_compatible,
)
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import (
ModelConfig,
reset_rng_states,
get_available_attention_backends,
)
# Reset RNG states
reset_rng_states()
param_types = [torch.float16]
if is_bf16_compatible():
param_types.append(torch.bfloat16)
model_configs_infer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"infer_0": ModelConfig(
4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
),
"infer_1": ModelConfig(
2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
),
# test: b, sq, hq, dqk,
"infer_0": ModelConfig(4, 64, 16, 128, total_requests=8, max_ctx_len=16),
"infer_1": ModelConfig(2, 66, 16, 256, num_gqa_groups=4, total_requests=6, max_ctx_len=16),
}
qkv_formats = ["bshd", "sbhd", "thd"]
......@@ -470,7 +465,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2)
if is_paged:
qkv_layout = "paged_kv_" + qkv_layout
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......
......@@ -364,6 +364,40 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
set_weight_tensor_tp_group_reduce(True) # reset
@run_debug_test
def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
from test_log import LOG_QUANTIZED_CONFIG
kwargs["config_file"].write(LOG_QUANTIZED_CONFIG)
kwargs["config_file"].flush()
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
set_weight_tensor_tp_group_reduce(gather_weight)
if WORLD_SIZE % 2 != 0:
return # skip
TP_SIZE = WORLD_SIZE // 2
DP_SIZE = 2
TP_RANK = WORLD_RANK % TP_SIZE
DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE
debug_api.set_tensor_reduction_group(NCCL_WORLD)
x, weight = _get_tensors(
parallel_mode,
weight_seed=TP_RANK * 1234,
data_seed=DP_RANK * 1234,
tp_size=TP_SIZE,
tp_rank=TP_RANK,
)
tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)]
tp_group = dist.new_group(ranks=tp_group_ranks)
model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group)
_run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group)
set_weight_tensor_tp_group_reduce(True) # reset
@run_debug_test
def test_log_expert_parallel(**kwargs):
"""
......
......@@ -24,22 +24,17 @@ def test_transformer_engine_no_config(feature_dirs):
# FP8 enabled - true by default
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
)[0]
# modify_tensor_enabled - False by default
# modify_tensor_enabled - (False, None) by default
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)
)[0]
# inspect_tensor_enabled - False by default
# inspect_tensor_enabled - (False, None) by default
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.attn.qkv", tensor_name="activation", iteration=0
)
# inspect_tensor_postquantize - False by default
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)
)[0]
finally:
debug_api.end_debug()
......@@ -51,24 +46,24 @@ def test_disable_fp8_gemm(configs_dir, feature_dirs):
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
)[0]
# caching
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
)[0]
finally:
debug_api.end_debug()
......@@ -80,22 +75,22 @@ def test_disable_fp8_layer(configs_dir, feature_dirs):
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="fprop", iteration=0
)
)[0]
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", iteration=0
)
)[0]
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
)[0]
finally:
debug_api.end_debug()
......@@ -111,22 +106,22 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
# check modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
)
)[0]
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0
)
)[0]
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
)
)[0]
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0
)
)[0]
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0
)
)[0]
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0
)
)[0]
# check modify_tensor
......@@ -168,14 +163,14 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
gemm="wgrad",
tensor_name="gradient",
iteration=0,
)
)[0]
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc4",
gemm="fprop",
tensor_name="activation",
iteration=0,
)
)[0]
finally:
debug_api.end_debug()
......@@ -191,11 +186,11 @@ def test_fake_quant(configs_dir, feature_dirs):
# modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
)
)[0]
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
)
)[0]
# modify_tensor
debug_api.transformer_engine.modify_tensor(
......@@ -218,11 +213,11 @@ def test_fake_quant(configs_dir, feature_dirs):
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0
)
)[0]
# caching
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0
)
)[0]
finally:
debug_api.end_debug()
......@@ -236,13 +231,12 @@ def test_statistics_collection(configs_dir, feature_dirs):
)
tensor = torch.randn((100, 100, 5)).cuda()
tensor_fp8 = Float8Tensor(
data=tensor.to(torch.uint8).cuda(),
fp8_scale_inv=torch.full([1], 1.0).cuda(),
quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=tensor.shape,
dtype=torch.float32,
)
tensor_fp8 = quantizer(tensor)
def log():
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
......@@ -260,54 +254,64 @@ def test_statistics_collection(configs_dir, feature_dirs):
tensor_name="activation",
iteration=200,
tp_group=None,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
stats = log()
assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", iteration=201
)
)[0]
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.2.mlp.fc1", tensor_name="activation", iteration=200
)
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
)[0]
expected_underflows = (
((tensor_fp8._data == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5)
)
expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5)
expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5)
assert debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
)[0]
# TE FP8 tensor stats --
assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
)
debug_api.transformer_engine.inspect_tensor_postquantize(
assert debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
)[0]
debug_api.transformer_engine.inspect_tensor(
"decoder.1.mlp.fc1",
tensor=tensor_fp8,
tensor_name="gradient",
iteration=200,
rowwise=True,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
stats = log()
torch.testing.assert_close(
stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows
)
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201
)
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
)
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", iteration=201
)[0]
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.2.mlp.fc1", tensor_name="gradient", iteration=200
)[0]
# Second config in same yaml
tensor = torch.rand((100, 100, 5))
debug_api.transformer_engine.inspect_tensor(
"decoder.6.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=200,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
stats = log()
stats_names = [x[3] for x in stats.keys()]
......@@ -316,10 +320,13 @@ def test_statistics_collection(configs_dir, feature_dirs):
debug_api.transformer_engine.inspect_tensor(
"decoder.7.mlp.fc1",
tensor=tensor,
tensor_name="weight",
iteration=200,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
stats = log()
stats_names = [x[3] for x in stats.keys()]
......@@ -328,7 +335,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.7.mlp.fc1", tensor_name="weight", iteration=201
)
)[0]
assert_empty()
finally:
......@@ -343,21 +350,16 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
default_logging_enabled=False,
)
def feed(tensor, tensor_fp8):
def feed(tensor, tensor_fp8, quantizer):
debug_api.transformer_engine.inspect_tensor(
"decoder.5.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=1,
tp_group=None,
)
debug_api.transformer_engine.inspect_tensor_postquantize(
"decoder.5.mlp.fc1",
tensor=tensor_fp8,
tensor_name="activation",
iteration=1,
rowwise=True,
tp_group=None,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
def log_stats():
......@@ -365,26 +367,26 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
return STATS_BUFFERS.log_stats()
quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
def fp8_tensor(t):
return Float8Tensor(
data=t.to(torch.uint8).cuda(),
fp8_scale_inv=torch.ones([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=t.shape,
dtype=torch.float32,
)
return quantizer(t.cuda())
shape = [1024, 1024]
tensors = [torch.randn(shape) for _ in range(2)]
tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)]
feed(tensors[0], tensors_fp8[0])
feed(tensors[1], tensors_fp8[1])
feed(tensors[0], tensors_fp8[0], quantizer)
feed(tensors[1], tensors_fp8[1], quantizer)
stats1 = log_stats()
tensor2 = torch.cat((tensors[0], tensors[1])).cuda()
fp8tensor2 = fp8_tensor(tensor2)
feed(tensor2, fp8tensor2)
feed(tensor2, fp8tensor2, quantizer)
stats2 = log_stats()
assert len(stats1.keys()) > 0
......
test:
enabled: True
layers:
layer_name_regex_pattern: .*
transformer_engine:
LogTensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [cur_amax, dynamic_range, mean, std, l1_norm]
start_step: 1
freq: 3
LogFp8TensorStats:
enabled: True
tensors: activation
stats: [underflows%]
start_step: 1
freq: 5
\ No newline at end of file
test:
enabled: True
layers:
layer_name_regex_pattern: .*1
transformer_engine:
LogTensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [cur_amax, dynamic_range, mean, std, l1_norm]
start_step: 0
freq: 100000
\ No newline at end of file
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as te
import torch
import tempfile
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import RecipeState
import pytest
import contextlib
import os
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.debug.pytorch.debug_state import TEDebugState
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
LOG_QUANTIZED_CONFIG_BASE = """
log:
layers:
layer_name_regex_pattern: .*
enabled:
True
transformer_engine:
LogFp8TensorStats:
enabled: True
stats: [
{stats}
]
tensors: [activation, gradient, weight]
freq: 2
start_step: 0
end_step: 10
"""
recipes = [
"fp8_delayed_scaling",
"fp8_current_scaling",
"fp8_block_scaling",
"mxfp8",
]
bare_stats = [
"underflows%",
"scale_inv_min",
"scale_inv_max",
"mse",
]
all_stats = []
for r in recipes:
for stat in bare_stats:
for columnwise_postfix in ["", "_columnwise"]:
if (
r in ["fp8_current_scaling", "fp8_block_scaling"]
and torch.cuda.get_device_capability()[0] < 9
):
# Hopper is needed for current-scaling and block-scaling
continue
if r == "mxfp8" and torch.cuda.get_device_capability()[0] < 10:
# Blackwell is needed for mxfp8
continue
if (
r in ["fp8_delayed_scaling", "fp8_current_scaling"]
and columnwise_postfix == "_columnwise"
):
# columnwise stats are not supported for fp8_delayed_scaling and fp8_current_scaling
continue
all_stats.append(f"{r}_{stat}{columnwise_postfix}")
all_stats.append("fp8_delayed_scaling_overflows%") # only delayed-scaling supports overflows%
@contextlib.contextmanager
def debug_session(config_str: str, feature_dirs):
"""
Helper context manager that
1. writes the YAML `config_str` to a temporary file,
2. starts a debug session, and
3. yields the directory that contains the statistics log.
The session is closed automatically – even on exceptions – so every test
stays concise and leak-free.
"""
with tempfile.NamedTemporaryFile(
mode="w", delete=False
) as cfg_file, tempfile.TemporaryDirectory() as log_dir:
cfg_file.write(config_str)
cfg_file.flush()
debug_api.initialize(
config_file=cfg_file.name,
feature_dirs=feature_dirs,
log_dir=log_dir,
)
try:
yield log_dir
finally:
debug_api.end_debug()
def read_log(log_dir: str) -> str:
"""Return the content of the statistics log produced by `debug_session`."""
stat_path = os.path.join(
log_dir,
"nvdlfw_inspect_statistics_logs",
"nvdlfw_inspect_globalrank-0.log",
)
with open(stat_path, "r") as f:
return f.read()
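# Hypothetical end-to-end sketch (the helper name `example_collect_stats` is illustrative,
# and an FP8-capable GPU is assumed): it mirrors the tests below by combining
# `debug_session` and `read_log` to collect quantized-tensor statistics for a single layer.
def example_collect_stats(feature_dirs):
    config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(bare_stats))
    with debug_session(config, feature_dirs) as log_dir:
        model = te.Linear(64, 64, params_dtype=torch.bfloat16)
        inp = torch.randn(64, 64, dtype=torch.bfloat16, device="cuda")
        for _ in range(4):
            with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
                out = model(inp)
            out.sum().backward()
            debug_api.step()
        return read_log(log_dir)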
def test_sanity(feature_dirs):
log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats))
with debug_session(log_all_stats_config, feature_dirs) as log_dir:
model = te.Linear(128, 128, params_dtype=torch.bfloat16)
inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
for _ in range(10):
with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
output = model(inp)
loss = output.sum()
loss.backward()
debug_api.step()
output = read_log(log_dir)
assert output, "Output is empty"
for stat in all_stats:
assert stat in output, f"Stat {stat} not found in output"
fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_numerics(fp8_recipe, feature_dirs):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
pytest.skip(reason_for_no_mxfp8)
if not fp8_block_scaling_available and fp8_recipe == recipe.Float8BlockScaling():
pytest.skip(reason_for_no_fp8_block_scaling)
log_only_bare_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(bare_stats))
with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
recipe_state = RecipeState.create(
fp8_recipe,
mode="forward",
num_quantizers=3,
)
tensor = torch.zeros(1024, 1024).cuda()
tensor[0, :] = 1000
quantizer = recipe_state.make_quantizers()[0]
quantized_tensor = quantizer(tensor)
debug_api.transformer_engine.inspect_tensor(
layer_name="layer_name",
tensor_name="activation",
iteration=0,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=quantized_tensor,
columnwise_quantized_tensor=quantized_tensor,
)
debug_api.step()
dequantized_tensor = quantized_tensor.dequantize()
output = read_log(log_dir)
for line in output.splitlines():
if "underflows%" in line:
underflows = float(line.split("value=")[1])
expected = (
((dequantized_tensor == 0).sum() - (tensor == 0).sum())
/ dequantized_tensor.numel()
* 100
)
assert underflows == pytest.approx(expected.cpu(), abs=1e-4)
if "mse" in line:
mse = float(line.split("value=")[1])
expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean")
assert mse == pytest.approx(expected.cpu(), abs=1e-6)
if "overflows%" in line:
overflows = float(line.split("value=")[1])
expected = (
(abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100
)
assert overflows == pytest.approx(expected.cpu(), abs=1e-4)
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
# If a layer does not invoke any feature in the current iteration,
# it switches into non-debug mode.
# This test checks that this works correctly:
# non-quantized statistics should be logged every 3 iterations,
# and quantized statistics should be logged every 5 iterations.
with tempfile.TemporaryDirectory() as temp_dir:
debug_api.initialize(
config_file=configs_dir + "/log_config.yaml",
feature_dirs=feature_dirs,
log_dir=temp_dir,
)
if layer == "linear":
model = te.Linear(128, 128, name="linear1")
elif layer == "transformer":
model = te.TransformerLayer(128, 128, 4, name="transformer1")
else:
raise ValueError(f"Invalid layer: {layer}")
for i in range(20):
x = torch.randn(4, 128, 128).cuda()
with te.fp8_autocast(enabled=True):
y = model(x)
y.sum().backward()
debug_api.step()
with open(
os.path.join(
temp_dir, "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"
),
"r",
) as f:
file_content = f.read()
for i in range(1, 20):
if i % 3 == 0 or i % 5 == 0:
assert f"iteration={i:06d}" in file_content
else:
assert f"iteration={i:06d}" not in file_content
debug_api.end_debug()
TEDebugState._reset()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import time
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug.pytorch.debug_state import TEDebugState
def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs):
debug_api.end_debug()
TEDebugState._reset()
if debug_tools_initialized:
# This config logs stats starting from iteration 0, every N iterations for a huge N >> NUM_ITERS.
# So after 1 warm-up iteration, these layers should run in non-debug mode.
debug_api.initialize(
config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs
)
try:
if layer == "linear":
model = torch.nn.Sequential(
te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2")
).cuda()
NUM_ITERS = 18000
elif layer == "transformer":
model = torch.nn.Sequential(
te.TransformerLayer(1, 1, 1, name="transformer1"),
te.TransformerLayer(1, 1, 1, name="transformer2"),
).cuda()
NUM_ITERS = 2000
x = torch.randn(1, 1, 1).cuda()
y = model(x)
y.sum().backward()
debug_api.step()
torch.cuda.synchronize()
time_start = time.time()
for i in range(NUM_ITERS):
y = model(x)
y.sum().backward()
if debug_tools_initialized:
debug_api.step()
torch.cuda.synchronize()
time_end = time.time()
finally:
if debug_tools_initialized:
debug_api.end_debug()
return time_end - time_start
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_cpu_overhead(layer, configs_dir, feature_dirs):
# Runs one layer many times on a very small tensor
# - GPU time should be negligible, so total time should be dominated by CPU time.
# If a layer does not invoke any feature in the current iteration,
# it switches into non-debug mode and should not add any non-negligible CPU overhead
# compared to a layer run without debug tools initialized.
with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs)
without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs)
print(f"with_debug_tools: {with_debug_tools} s")
print(f"without_debug_tools: {without_debug_tools} s")
assert with_debug_tools < without_debug_tools * 1.25 # 25% overhead margin
......@@ -519,6 +519,7 @@ def _train(opts):
if opts.use_cuda_graphs:
del test_graph
torch.cuda.synchronize()
te.module.base.destroy_ub()
dist_print("Destroying Userbuffers objects...", debug=True)
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pathlib
import sys
import pytest
import torch
import transformer_engine
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch import TransformerLayer, Linear
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig
model_configs = {
"small": ModelConfig(2, 10, 2, 16),
}
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"])
def test_current_device(model, module):
"""Test cases where current device is different from tensor device"""
num_devices = torch.cuda.device_count()
assert num_devices > 1, "This test requires more than one GPU!"
tensor_device = num_devices - 1
dtype = torch.bfloat16
config = model_configs[model]
args = []
kwargs = {}
bwd_args = []
if module == "TransformerLayer":
model = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_heads,
params_dtype=dtype,
attn_input_format="thd",
self_attn_mask_type="padding",
device=f"cuda:{tensor_device}",
)
num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
args = [
torch.randn(
(num_tokens, config.hidden_size),
dtype=dtype,
device=f"cuda:{tensor_device}",
requires_grad=True,
)
]
cu_seqlens_q, cu_seqlens_kv = [
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
if module == "DotProductAttention":
model = DotProductAttention(
config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
)
num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
args = [
torch.randn(
num_tokens,
config.num_heads,
config.head_dim_qk,
dtype=dtype,
device=tensor_device,
requires_grad=True,
)
for _ in range(3)
]
cu_seqlens_q, cu_seqlens_kv = [
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
]
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)]
elif module == "Linear":
model = Linear(
config.hidden_size,
4 * config.hidden_size,
params_dtype=dtype,
device=f"cuda:{tensor_device}",
)
args = [
torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device=f"cuda:{tensor_device}",
requires_grad=True,
)
]
current_device_before = torch.cuda.current_device()
out = model(*args, **kwargs)
if module == "DotProductAttention":
out.backward(*bwd_args)
else:
loss = out.sum()
loss.backward()
current_device_after = torch.cuda.current_device()
tensor_device_out = out.get_device()
tensor_device_grad = args[0].grad.get_device()
assert (
current_device_after == current_device_before
), "The current device should not have changed!"
assert (
tensor_device_out == tensor_device
), "The output tensor should be on the same device as the input tensors!"
assert (
tensor_device_grad == tensor_device
), "The gradient tensor should be on the same device as the input tensors!"
......@@ -10,22 +10,24 @@ import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from utils import ModelConfig, get_available_attention_backends
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_recipes = [
None, # non-fp8
# recipe.MXFP8BlockScaling(), - scale inverse tensor offloading does not work yet
recipe.Float8CurrentScaling(),
recipe.DelayedScaling(),
]
fp8_recipes = [None]
if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
SIZE = 512
NUM_HEADS = 8
NUM_LAYERS = 5
EPSILON = 0.1
model_config = {
"small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1),
}
SIZE = model_config["small"].hidden_size
NUM_HEADS = model_config["small"].num_heads
NUM_LAYERS = model_config["small"].num_layers
EPSILON = model_config["small"].eps
# Flash attention saves some internal tensors for the backward pass
# that cannot be offloaded to CPU.
......@@ -124,11 +126,17 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
model_cls = model_types[model_key]
models_list = [model_cls() for _ in range(NUM_LAYERS)]
if fp8_recipe and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None:
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if model_key in ["multihead_attention", "transformer_layer"]:
available_backends, *_ = get_available_attention_backends(
model_config["small"],
qkv_dtype=torch.bfloat16,
qkv_layout="sbhd_sbhd_sbhd",
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
pytest.skip("Fused attention backend not available.")
os.environ["NVTE_FLASH_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
without_offloading = _measure_memory_between_forward_and_backward(
models_list, fp8_recipe, False
......
......@@ -2,9 +2,7 @@
#
# See LICENSE for license information.
from dataclasses import dataclass
import itertools
from typing import Iterable, List, Tuple, Union
from typing import Iterable, List, Union
import pytest
import torch
......@@ -23,46 +21,32 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe
from utils import ModelConfig, reset_rng_states
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
import os
from functools import cache
# Check if FP8 is supported.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Record initial RNG state.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
@dataclass
class ModelConfig:
"""Data tensor dimensions within Transformer model"""
sequence_length: int
batch_size: int
hidden_size: int
num_heads: int
kv_channels: int
model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
fp8_recipes = [
recipe.DelayedScaling(),
recipe.MXFP8BlockScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
# Reset RNG states.
reset_rng_states()
model_configs = {
"small": ModelConfig(32, 2, 2, 32),
}
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_block_scaling_available:
fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
......@@ -70,12 +54,6 @@ if is_bf16_compatible(): # bf16 requires sm_80 or higher
dtypes.append(torch.bfloat16)
def reset_rng_states() -> None:
"""Revert to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
......@@ -119,7 +97,7 @@ def generate_data(
"""Generate synthetic data."""
gen_func = torch.ones if warmup else torch.randn
return gen_func(
model_config.sequence_length,
model_config.max_seqlen_q,
model_config.batch_size,
model_config.hidden_size,
device="cuda",
......@@ -157,10 +135,12 @@ class _Sequential(torch.nn.Sequential):
# Supported modules
_test_cuda_graphs_modules: List[str] = [
# Put linear first to test the case where the CUDA context might not be set when
# creating the TMA descriptor for MXFP8 quantization.
"linear",
"transformer",
"layernorm_mlp",
"layernorm_linear",
"linear",
"mha",
"linear_op",
]
......@@ -310,35 +290,27 @@ def _test_cuda_graphs(
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
def test_make_graphed_callables(
*,
module: str,
model_config: str = "small",
num_layers: int = 3,
dtype: torch.dtype,
fp8: bool,
fp8_params: bool,
fp8_recipe: recipe.Recipe,
fp8_weight_caching: bool = False,
) -> None:
# Skip invalid configurations.
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
fp8 = fp8_recipe is not None
if fp8_params and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and module == "linear_op":
if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
# Run model with different CUDA graph settings.
model_config = model_configs[model_config]
kwargs = dict(
......@@ -351,9 +323,11 @@ def test_make_graphed_callables(
fp8_weight_caching=fp8_weight_caching,
fp8_recipe=fp8_recipe,
)
outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
# Put graphed callables first to test the case where the CUDA context might not be set when
# creating the TMA descriptor for MXFP8 quantization.
graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
# Check that results match.
assert_all_equal(outputs, graph_outputs_mode1)
......@@ -369,7 +343,6 @@ _test_make_graphed_callables_with_fp8_weight_caching_modules = [
]
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize(
"module",
_test_make_graphed_callables_with_fp8_weight_caching_modules,
......@@ -385,7 +358,6 @@ def test_make_graphed_callables_with_fp8_weight_caching(
test_make_graphed_callables(
module=module,
dtype=torch.float32,
fp8=True,
fp8_params=fp8_params,
fp8_recipe=fp8_recipe,
fp8_weight_caching=True,
......@@ -401,7 +373,7 @@ def generate_data_for_dot_product_attention(
gen_func = torch.ones if warmup else torch.randn
return [
gen_func(
model_config.sequence_length,
model_config.max_seqlen_q,
model_config.batch_size,
model_config.num_heads,
model_config.kv_channels,
......@@ -495,8 +467,8 @@ def _test_cuda_graphs_with_kwargs(
(
model_config.batch_size,
1,
model_config.sequence_length,
model_config.sequence_length,
model_config.max_seqlen_q,
model_config.max_seqlen_kv,
),
dtype=torch.bool,
device="cuda",
......@@ -522,8 +494,8 @@ def _test_cuda_graphs_with_kwargs(
(
model_config.batch_size,
1,
model_config.sequence_length,
model_config.sequence_length,
model_config.max_seqlen_q,
model_config.max_seqlen_kv,
),
dtype=torch.bool,
device="cuda",
......
......@@ -223,7 +223,7 @@ class TestFloat8BlockwiseTensor:
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
all_gather_usage=True,
all_gather_usage=(block_scaling_dim == 1),
)
self._test_quantize_dequantize(
quantizer=quantizer,
......
......@@ -2,7 +2,6 @@
#
# See LICENSE for license information.
from itertools import product
import copy
from contextlib import nullcontext
......@@ -112,13 +111,6 @@ class TestFusedAdam(TestFusedOptimizer):
def test_bfloat16(self):
self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
......@@ -530,13 +522,6 @@ class TestFusedSGD(TestFusedOptimizer):
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
class Model(torch.nn.Module):
def __init__(self):
......
......@@ -2,8 +2,7 @@
#
# See LICENSE for license information.
import torch
import math
from typing import Optional, Dict
from typing import Optional
from transformer_engine.pytorch.router import (
fused_topk_with_score_function,
fused_compute_score_for_moe_aux_loss,
......@@ -149,11 +148,21 @@ def run_comparison(
# Set some parameters
if score_function == "sigmoid":
# Construct the special logits to avoid inf in the sigmoid function
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
logits = (
torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
)
logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
else:
logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
logits = (
torch.arange(
-num_tokens * num_experts // 2,
num_tokens * num_experts // 2,
device="cuda",
dtype=dtype,
)
* 1e-4
)
logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True
if enable_bias and score_function == "sigmoid":
......@@ -282,11 +291,21 @@ def test_topk_softmax(
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
if score_function == "sigmoid":
# Construct the special logits to avoid inf in the sigmoid function
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
logits = (
torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
)
logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
else:
logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
logits = (
torch.arange(
-num_tokens * num_experts // 2,
num_tokens * num_experts // 2,
device="cuda",
dtype=dtype,
)
* 1e-4
)
logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True
......@@ -322,8 +341,8 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f
@pytest.mark.parametrize("topk", [4])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
# Construct the special probs to avoid inf in the sigmoid function
offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
probs = probs.view(num_tokens, num_experts)
probs.requires_grad = True
......@@ -380,15 +399,12 @@ def profile_topk_softmax(
if __name__ == "__main__":
test_fused_scores_for_aux_loss(
dtype=torch.float32, num_tokens=2, num_experts=32, topk=8, score_function="softmax"
test_topk_softmax(
dtype=torch.float32,
num_tokens=1024,
num_experts=128,
topk=4,
use_pre_softmax=False,
group_topk=None,
scaling_factor=None,
)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4)
......@@ -21,10 +21,12 @@ import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
BackwardBiasActivation,
BackwardActivationBias,
BackwardLinearAdd,
BackwardLinearScale,
ForwardLinearBiasActivation,
ForwardLinearBiasAdd,
ForwardLinearScaleAdd,
)
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
......@@ -39,7 +41,7 @@ import transformer_engine_torch as tex
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent))
from utils import dtype_tols, make_recipe
from utils import dtype_tols, make_recipe, reset_rng_states
if IS_HIP_EXTENSION:
import os
......@@ -271,16 +273,72 @@ class TestSequentialContainer:
model(torch.zeros(1))
assert len(model._module_groups) == 6
def test_extra_tensors(self, size: int = 16) -> None:
"""Check that extra inputs are distributed properly between module groups
and that extra outputs are properly collected"""
# Construct sequential container
bias = te_ops.Bias(size=size, device="cpu")
with torch.no_grad():
bias.bias.copy_(torch.rand((size,)))
model = te_ops.Sequential( # | Inputs | Outputs
torch.nn.Identity(), # | x1 | x1
te_ops.MakeExtraOutput(in_place=True), # | x1 | x1 [x1]
bias, # | x1 | h1 (= x1 + b)
te_ops.MakeExtraOutput(in_place=True), # | h1 | h1 [h1]
te_ops.AddExtraInput(in_place=True), # | h1 [x2] | x2 (= x2 + h1)
te_ops.MakeExtraOutput(in_place=True), # | x2 | x2 [x2]
torch.nn.Identity(), # | x2 | x2
bias, # | x2 | h2 (= x2 + b)
te_ops.AddExtraInput(in_place=True), # | h2 [x3] | x3 (= x3 + h2)
te_ops.MakeExtraOutput(in_place=True), # | x3 | x3 [x3]
te_ops.AddExtraInput(in_place=True), # | x3 [x4] | x4 (= x4 + x3)
torch.nn.Identity(), # | x4 | x4
te_ops.Identity(), # | x4 | x4
te_ops.MakeExtraOutput(in_place=True), # | x4 | x4 [x4]
te_ops.Identity(), # | x4 | x4
)
# Create input tensors
x1 = torch.rand((size,))
x2 = torch.rand((size,))
x3 = torch.rand((size,))
x4 = torch.rand((size,))
# Save original input tensor values
x1_orig = x1.clone()
x2_orig = x2.clone()
x3_orig = x3.clone()
x4_orig = x4.clone()
# Run forward
ys = model(x1, x2, x3, x4)
# Check whether outputs match (x4, x1, h1, x2, x3, x4)
assert len(ys) == 6
assert ys[0].data_ptr() == x4.data_ptr()
assert ys[1].data_ptr() == x1.data_ptr()
assert ys[2].data_ptr() not in [x.data_ptr() for x in (x1, x2, x3, x4)]
assert ys[3].data_ptr() == x2.data_ptr()
assert ys[4].data_ptr() == x3.data_ptr()
assert ys[5].data_ptr() == x4.data_ptr()
# Check whether tensors have correct values
b = bias.bias
h1 = ys[2]
torch.testing.assert_close(x1, x1_orig)
torch.testing.assert_close(h1, x1_orig + b)
torch.testing.assert_close(x2, x2_orig + h1)
torch.testing.assert_close(x3, x3_orig + x2 + b)
torch.testing.assert_close(x4, x4_orig + x3)
class TestFuser:
"""Tests for operation fusion infrastructure"""
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
reset_rng_states()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_scale_update(
......@@ -494,10 +552,7 @@ class TestBasicOps:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
reset_rng_states()
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
......@@ -795,10 +850,9 @@ class TestBasicOps:
pytest.skip("FP8 output is only supported with FP8 GEMMs")
if quantized_grad_input and not quantized_compute:
pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
if quantization == "mxfp8" and quantized_output:
pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
if quantization == "mxfp8" and quantized_grad_input:
pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")
if quantization not in (None, "fp8"):
if quantized_output or quantized_grad_input:
pytest.skip("Recipe does not support quantized GEMM output")
if ( IS_HIP_EXTENSION and not use_hipblaslt() and
accumulate_into_main_grad and dtype != torch.float32 and not quantized_compute):
pytest.skip("Parameters combination is not supported by ROCBLAS")
......@@ -1353,18 +1407,17 @@ class TestBasicOps:
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
# L2Norm backward pass requires slightly looser atol for bfloat16
if dtype == torch.bfloat16:
tols["atol"] = 2e-3
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("in_place", (True, False))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("quantization", _quantization_list)
def test_add_in_place(
def test_add_extra_input(
self,
*,
in_shape: Iterable[int] = (32, 32),
in_place: bool,
dtype: torch.dtype,
device: torch.device,
quantization: Optional[str],
......@@ -1410,7 +1463,7 @@ class TestBasicOps:
dx2_ref = dy_ref
# Implementation with fusible operation
op = te_ops.AddInPlace()
op = te_ops.AddExtraInput(in_place=in_place)
y_test = op(x1_test, x2_test)
y_test.backward(dy_test)
......@@ -1425,6 +1478,7 @@ class TestBasicOps:
torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0)
torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0)
@pytest.mark.parametrize("in_place", (True, False))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("quantization", _quantization_list)
......@@ -1432,6 +1486,7 @@ class TestBasicOps:
self,
*,
in_shape: Iterable[int] = (32, 32),
in_place: bool,
dtype: torch.dtype,
device: torch.device,
quantization: Optional[str],
......@@ -1477,7 +1532,7 @@ class TestBasicOps:
(y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()
# Implementation with fusible operation
op = te_ops.MakeExtraOutput()
op = te_ops.MakeExtraOutput(in_place=in_place)
y1_test, y2_test = op(x_test)
(y1_test * dy1_test + y2_test * dy2_test).sum().backward()
......@@ -1645,16 +1700,107 @@ class TestBasicOps:
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", _devices)
def test_constant_scale(
self,
*,
scale: float,
shape: Iterable[int],
dtype: torch.dtype,
device: torch.device,
):
# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = scale * x_ref
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.ConstantScale(scale)
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("prob", (0.1, 0.5, 0.75))
@pytest.mark.parametrize("is_training", (True, False))
@pytest.mark.parametrize("shape", ((101,), (2, 4, 16)))
@pytest.mark.parametrize("dtype", _dtypes)
def test_dropout(
self,
*,
prob: float,
is_training: bool,
shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
):
# Random data
x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
x_test = x_ref.clone().requires_grad_()
dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
dy_test = dy_ref.clone()
# Apply dropout
op = te_ops.Dropout(prob)
if is_training:
op.train()
else:
op.eval()
y = op(x_test)
y.backward(dy_test)
# Check values
if is_training:
mask = ((y != 0) / (1 - prob)).to(dtype=dtype)
torch.testing.assert_close(y, x_ref * mask)
torch.testing.assert_close(x_test.grad, dy_ref * mask)
else:
torch.testing.assert_close(y, x_ref, rtol=0, atol=0)
torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0)
# Hypothesis testing for number of zeros
# Note: A Bernoulli random variable with probability p has
# mean p and standard deviation sqrt(p*(1-p)). By the central
# limit theorem, the mean of n iid Bernoulli variables
# converges to a normal random variable with mean p and
# standard deviation sqrt(p*(1-p)/n). If the observed mean is
# below the 0.5th or above the 99.5th percentiles, then the
# p-value is less than 1% and we assume that the dropout
# distribution is incorrect.
if is_training:
prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval"
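# Note: 2.5758 is the 99.5th percentile of the standard normal distribution,
# i.e. the two-sided 99% bound from the CLT argument above; for reference,
# torch.distributions.Normal(0.0, 1.0).icdf(torch.tensor(0.995)) is ~2.5758.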
class TestFusedOps:
"""Tests for fused operations"""
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
reset_rng_states()
@pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
......@@ -1841,7 +1987,7 @@ class TestFusedOps:
device=device,
dtype=dtype,
),
te_ops.AddInPlace(),
te_ops.AddExtraInput(in_place=True),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
......@@ -1878,11 +2024,114 @@ class TestFusedOps:
db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols)
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_forward_linear_scale_add(
self,
*,
scale: float,
weight_shape: tuple[int, int] = (32, 32),
in_shape: Iterable[int] = (32, -1),
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weight: bool = False,
) -> None:
"""Forward GEMM + scale + add"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
# Random data
x1_ref, x1_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
x2_ref, x2_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x1_ref, w_ref) * scale + x2_ref
y_ref.backward(dy_ref)
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
out_features,
bias=False,
device=device,
dtype=dtype,
),
te_ops.ConstantScale(scale),
te_ops.AddExtraInput(in_place=True),
te_ops.Quantize(),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = model(x1_test, x2_test)
y_test.backward(dy_test)
# Check that forward operations have been fused
forward_ops = model._module_groups[0]._forward_ops
assert len(forward_ops) == 2
assert isinstance(forward_ops[0][0], ForwardLinearScaleAdd)
assert isinstance(forward_ops[1][0], te_ops.Quantize)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx1_test, x1_ref.grad, **tols)
torch.testing.assert_close(dx2_test, x2_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@pytest.mark.parametrize("activation", ("relu", "gelu"))
@pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_bias_activation(
def test_backward_activation_bias(
self,
*,
activation: str,
......@@ -1891,7 +2140,7 @@ class TestFusedOps:
device: torch.device = "cuda",
quantization: Optional[str],
) -> None:
"""Backward dbias + dact + quantize"""
"""Backward dact + dbias + quantize"""
# Tensor dimensions
in_shape = list(out_shape)
......@@ -1948,9 +2197,9 @@ class TestFusedOps:
# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]:
if with_quantization:
assert len(backward_ops) == 2
assert isinstance(backward_ops[0][0], BackwardBiasActivation)
assert isinstance(backward_ops[0][0], BackwardActivationBias)
assert isinstance(backward_ops[1][0], te_ops.Quantize)
else:
assert len(backward_ops) == 3
......@@ -1963,6 +2212,7 @@ class TestFusedOps:
if with_quantization:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
......@@ -2033,7 +2283,7 @@ class TestFusedOps:
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight):
model = te_ops.Sequential(
te_ops.MakeExtraOutput(),
te_ops.MakeExtraOutput(in_place=True),
te_ops.Linear(
in_features,
out_features,
......@@ -2071,16 +2321,106 @@ class TestFusedOps:
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_linear_scale(
self,
*,
scale: float,
weight_shape: tuple[int, int] = (32, 32),
in_shape: Iterable[int] = (32, -1),
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weight: bool = False,
) -> None:
"""Backward dgrad GEMM + scale"""
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
quantized_compute = quantization is not None
maybe_skip_quantization(quantization, dims=in_shape, device=device)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref) * scale
y_ref.backward(dy_ref)
# Implementation with fusible operations
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
out_features,
bias=False,
device=device,
dtype=dtype,
),
te_ops.ConstantScale(scale),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = model(x_test)
(y_test * dy_test).sum().backward()
# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(backward_ops[0][0], BackwardLinearScale)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
class TestCheckpointing:
"""Tests for checkpointing"""
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
reset_rng_states()
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_weight", (False, True))
......@@ -2192,11 +2532,9 @@ class TestSequentialModules:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
reset_rng_states()
@pytest.mark.parametrize("requires_grad", (False, True))
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
@pytest.mark.parametrize("quantized_compute", (False, True))
......@@ -2206,6 +2544,7 @@ class TestSequentialModules:
def test_layernorm_mlp(
self,
*,
requires_grad: bool,
bias: bool,
normalization: str,
quantized_compute: bool,
......@@ -2246,6 +2585,7 @@ class TestSequentialModules:
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=requires_grad,
)
_, dy_test = make_reference_and_test_tensors(
in_shape,
......
......@@ -7,7 +7,6 @@ from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.utils import is_bf16_compatible
class SimpleTEModel(PreTrainedModel):
......