Commit 87e3e56e authored by yuguo


Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
......@@ -2,7 +2,6 @@
#
# See LICENSE for license information.
from collections import OrderedDict
import math
import os
from typing import Dict, List, Tuple, Optional
......@@ -39,54 +38,39 @@ from transformer_engine.pytorch import (
Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states, get_available_attention_backends
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
sm_80plus = get_device_compute_capability() >= (8, 0)
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
# Reset RNG states.
reset_rng_states()
if torch_version() >= (2, 7, 0):
torch._dynamo.config.recompile_limit = 16
else:
torch._dynamo.config.cache_size_limit = 16
class ModelConfig:
def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
self.hidden_size = hidden_size
self.eps = eps
self.num_attention_heads = num_attention_heads
self.embed = embed
self.num_layers = num_layers
self.seq_len = seq_len
model_configs = {
"small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
"small": ModelConfig(1, 128, 8, 16, num_layers=4),
"126m": ModelConfig(1, 2048, 12, 64, num_layers=12),
}
model_configs_inference = {
# See utils.ModelConfig for the argument order.
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
"126m": ModelConfig(1, 256, 12, 64, num_layers=12),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
......@@ -120,12 +104,27 @@ if NVTE_TEST_NVINSPECT_ENABLED:
feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
)
fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_block_scaling_available:
fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
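# Building the recipe list from hardware capability up front means a recipe only
# appears here if the current device supports it, so the parametrized tests below
# no longer need per-test pytest.skip() guards for unavailable FP8 formats.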
def is_fused_attn_available(
config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
):
_, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
)
return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends
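# Minimal usage sketch, mirroring how the tests below call this helper to skip
# configurations with no usable fused-attention backend:
#
#     config = model_configs["126m"]
#     if not is_fused_attn_available(config, torch.bfloat16, qkv_layout="sb3hd", is_training=False):
#         pytest.skip("No attention backend available.")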
def get_causal_attn_mask(sq: int) -> torch.Tensor:
......@@ -177,12 +176,6 @@ def assert_allclose(
raise AssertionError(msg)
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
......@@ -535,13 +528,13 @@ def _test_e2e_selective_recompute(
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
......@@ -550,13 +543,13 @@ def _test_e2e_selective_recompute(
)
te_inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
te_out = block(
......@@ -582,14 +575,8 @@ def _test_e2e_selective_recompute(
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
......@@ -630,13 +617,13 @@ def _test_e2e_full_recompute(
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
......@@ -645,14 +632,14 @@ def _test_e2e_full_recompute(
)
te_inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=use_reentrant,
)
if use_reentrant:
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if recompute:
......@@ -698,14 +685,8 @@ def _test_e2e_full_recompute(
def test_gpt_full_activation_recompute(
dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
......@@ -761,13 +742,13 @@ def _test_e2e_checkpointing_get_model(config, dtype):
return TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
......@@ -779,7 +760,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
reset_rng_states()
te_inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
......@@ -809,14 +790,14 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
if p.requires_grad:
param_grads.append(p.grad.clone())
global _cpu_rng_state, _cuda_rng_state
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
del block
block = _test_e2e_checkpointing_get_model(config, dtype)
block.load_state_dict(torch.load(path, weights_only=False))
reset_rng_states()
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
for p in block.parameters():
if p.requires_grad:
......@@ -849,6 +830,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model]
if not is_fused_attn_available(config, dtype):
pytest.skip("No attention backend available.")
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
......@@ -869,13 +852,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
reset_rng_states()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len)
inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
out = block(inp_hidden_states, attention_mask=inp_attn_mask)
loss = out.sum()
......@@ -895,11 +878,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model]
if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
pytest.skip("No attention backend available.")
te_gpt = TransformerLayer(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_attention_heads=config.num_heads,
layernorm_epsilon=config.eps,
attention_dropout=0.1,
hidden_dropout=0.1,
......@@ -914,7 +899,7 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
TorchGPT(
config.hidden_size,
config.eps,
config.num_attention_heads,
config.num_heads,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
......@@ -975,13 +960,13 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
reset_rng_states()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None
inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) if mask_type == "causal" else None
forward_kwargs = {}
if te:
......@@ -1006,10 +991,12 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
config = model_configs[model]
if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
pytest.skip("No attention backend available.")
te_mha = MultiheadAttention(
config.hidden_size,
config.num_attention_heads,
config.num_heads,
fuse_qkv_params=True,
params_dtype=dtype,
qkv_weight_interleaved=False,
......@@ -1020,7 +1007,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
torch_mha = (
TorchMHA(
config.hidden_size,
config.num_attention_heads,
config.num_heads,
)
.to(dtype=dtype)
.cuda()
......@@ -1066,7 +1053,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False,
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
......@@ -1098,11 +1085,12 @@ def _test_dpa_accuracy(block, bs, dtype, config):
reset_rng_states()
mask = torch.triu(
torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1
torch.ones(config.max_seqlen_q, config.max_seqlen_kv, dtype=torch.bool, device="cuda"),
diagonal=1,
)
query, key, value = [
torch.randn(
(config.seq_len, bs, config.num_attention_heads, config.embed),
(config.max_seqlen_q, bs, config.num_heads, config.kv_channels),
dtype=dtype,
device="cuda",
requires_grad=True,
......@@ -1131,8 +1119,8 @@ def test_dpa_accuracy(dtype, bs, model):
te_dpa = (
DotProductAttention(
config.num_attention_heads,
config.embed,
config.num_heads,
config.kv_channels,
attention_dropout=0.0, # disable dropout, FU uses rng differently
)
.to(dtype=dtype)
......@@ -1141,7 +1129,7 @@ def test_dpa_accuracy(dtype, bs, model):
torch_dpa = (
TorchDotProductAttention(
config.embed,
config.kv_channels,
0.0, # dropout
)
.to(dtype=dtype)
......@@ -1267,8 +1255,8 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
te_linear_ref, bs, dtype, config, delay_wgrad_compute=False
)
# Shoule be bit-wise match
for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
# Should be bit-wise match
for _, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
......@@ -1280,17 +1268,12 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
fuse_wgrad_accumulation = True
fp8_model_params = False
fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
......@@ -1653,14 +1636,12 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("bs", [2])
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation
dtype, bs, model, bias, fuse_wgrad_accumulation
):
config = model_configs[model]
......@@ -1669,7 +1650,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=True,
......@@ -1681,7 +1661,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=False,
......@@ -1691,8 +1670,7 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
# Share params
with torch.no_grad():
ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
if normalization != "RMSNorm":
ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
if bias:
......@@ -1730,7 +1708,7 @@ def _test_grouped_linear_accuracy(
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
......@@ -1743,14 +1721,14 @@ def _test_grouped_linear_accuracy(
split_size = 16
if recipe.mxfp8():
split_size = 128
m = config.seq_len // split_size
m = config.max_seqlen_q // split_size
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * split_size
assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
assert m_splits.sum() == config.max_seqlen_q and len(m_splits) == num_gemms
else:
m_splits = torch.tensor([config.seq_len])
m_splits = torch.tensor([config.max_seqlen_q])
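# The random splits above partition max_seqlen_q into num_gemms chunks, each a
# multiple of split_size (16, or 128 for MXFP8), so their sizes sum back to the
# full sequence length; otherwise a single full-length split is used.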
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if isinstance(block, GroupedLinear):
......@@ -1806,17 +1784,11 @@ def test_grouped_linear_accuracy(
parallel_mode=None,
):
fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
......@@ -1908,19 +1880,13 @@ def test_grouped_linear_accuracy_save_original_input(
parallel_mode=None,
):
fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
......@@ -2074,14 +2040,14 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
(config.seq_len * bs, config.hidden_size),
(config.max_seqlen_q * bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)
m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if isinstance(block, TorchGroupedLinearWithPadding):
......@@ -2124,17 +2090,11 @@ def test_padding_grouped_linear_accuracy(
fp8_model_params,
parallel_mode=None,
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
......@@ -2199,19 +2159,13 @@ def test_padding_grouped_linear_accuracy_save_original_input(
fp8_model_params,
parallel_mode=None,
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
......@@ -2268,9 +2222,11 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
# Placeholders used for graph capture.
static_input = torch.randn(
config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
)
static_target = torch.randn(
config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype
)
static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype)
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
......@@ -2334,7 +2290,7 @@ def test_gpt_cuda_graph(dtype, bs, model):
block_args = (
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
)
block_kwargs = dict(
layernorm_epsilon=config.eps,
......@@ -2342,7 +2298,7 @@ def test_gpt_cuda_graph(dtype, bs, model):
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
......@@ -2377,13 +2333,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.embed,
kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
......@@ -2392,13 +2348,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
)
te_inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
......@@ -2418,14 +2374,8 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
......@@ -2461,13 +2411,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
block_sbhd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
......@@ -2482,13 +2432,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
block_bshd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
......@@ -2500,13 +2450,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
block_thd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
kv_channels=config.embed,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
......@@ -2521,15 +2471,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical"
x_sbhd = torch.randn(
(config.seq_len, bs, config.hidden_size),
(config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
x_bshd = x_sbhd.transpose(0, 1).contiguous()
x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous()
x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len
x_thd = x_bshd.reshape(bs * config.max_seqlen_q, config.hidden_size).contiguous()
x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.max_seqlen_q
# To make sure forward is also identical (just in case some module decides
# to act fancy)
......@@ -2556,167 +2506,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
x_thd,
cu_seqlens_q=x_thd_cumsum,
cu_seqlens_kv=x_thd_cumsum,
max_seqlen_q=config.seq_len,
max_seqlen_kv=config.seq_len,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
)
torch.testing.assert_close(
y_bshd,
y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model_key", model_configs_inference.keys())
@pytest.mark.parametrize("use_RoPE", all_boolean)
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
@pytest.mark.parametrize("is_paged", [False, True])
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
reset_rng_states()
if backend in ["FusedAttention"]:
pytest.skip("Not support FusedAttention")
if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
pytest.skip("FusedAttention and FlashAttention do not support FP32")
if use_RoPE:
pytest.skip("KV cache does not support starting positions for RoPE")
if (
backend == "FusedAttention"
and get_device_compute_capability() == (8, 9)
and get_cudnn_version() < (9, 12, 0)
):
pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12")
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
elif backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
elif backend == "UnfusedAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
config = model_configs_inference[model_key]
S = config.seq_len
B = bs
H = config.num_attention_heads
D = config.hidden_size
head_size = config.embed
layer_number = 1
# Limits the max size of KV-cache
B_max = B
S_max = S
if module == "TransformerLayer":
model = TransformerLayer(
hidden_size=D,
ffn_hidden_size=4 * D,
num_attention_heads=H,
attn_input_format=input_format,
self_attn_mask_type="causal",
enc_dec_attn_mask_type="causal",
layer_number=layer_number,
attention_dropout=0.0,
params_dtype=dtype,
device="cuda",
).eval()
else:
model = (
MultiheadAttention(
hidden_size=D,
num_attention_heads=H,
qkv_format=input_format,
layer_number=layer_number,
attention_dropout=0.0,
attn_mask_type="causal",
params_dtype=dtype,
)
.cuda()
.eval()
)
inference_params = InferenceParams(
max_batch_size=B_max,
max_sequence_length=S_max,
num_heads_kv=H,
head_dim_k=head_size,
dtype=dtype,
is_paged=is_paged,
total_num_pages=int(B_max * S_max / 256),
page_size=256,
)
rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
input = torch.randn((S, B, D), dtype=dtype, device="cuda")
if input_format == "bshd":
input = input.transpose(0, 1).contiguous()
incremental_output = torch.zeros_like(input)
# Generate output for the entire sequence
full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
# Incrementally generate outputs using the KV cache
step_dict = OrderedDict(zip(list(range(B)), [1] * B))
for i in range(S):
inference_params.pre_step(step_dict)
if input_format == "sbhd":
incremental_input = input[i].view(1, B, D)
else:
incremental_input = input[:, i, :].view(B, 1, D)
seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv = cu_seqlens_q.clone()
mask_type = "padding"
kwargs = {}
if module == "TransformerLayer":
kwargs["self_attn_mask_type"] = mask_type
else:
kwargs["attn_mask_type"] = mask_type
line_output = model(
hidden_states=incremental_input,
inference_params=inference_params,
rotary_pos_emb=rotary_freqs if use_RoPE else None,
**kwargs,
max_seqlen_q=1,
max_seqlen_kv=S,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(),
)
if input_format == "sbhd":
incremental_output[i, :, :] = line_output.view(B, D)
else:
incremental_output[:, i, :] = line_output.view(B, D)
if module == "TransformerLayer":
atol = {
torch.float32: 5e-3,
torch.half: 5e-3,
torch.bfloat16: 5e-2,
}
else:
atol = {
torch.float32: 1e-3,
torch.half: 1e-3,
torch.bfloat16: 1e-2,
}
# Check if the fully generated output matches the one generated incrementally
assert_allclose(full_output, incremental_output, atol[dtype])
@pytest.mark.parametrize(
"shape",
......@@ -2815,9 +2613,8 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
(16, 4096, 128, 512),
],
)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("accumulate", [False, True])
def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
def test_fp8_grouped_gemm(shape, accumulate):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
......
......@@ -27,7 +27,6 @@ import warnings
import numpy as np
import onnxruntime as ort
import torch
import random
from torch import nn as nn
from typing import Optional, Union, Tuple, List
from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
......@@ -59,14 +58,13 @@ TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
skip_MXFP8 = pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
fp8_recipes = [
None,
recipe.DelayedScaling(),
recipe.MXFP8BlockScaling(),
]
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_available:
fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(None)
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
......@@ -369,14 +367,6 @@ def validate_result(
)
def create_meta(scale_factor: float, size: int = 1):
meta = tex.FP8TensorMeta()
meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
return meta
def dtype2str(dtype: torch.dtype, fake_bf16_io=False):
if fake_bf16_io:
assert dtype == torch.bfloat16
......@@ -413,36 +403,12 @@ Test cases begin here.
"""
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize(
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, False),
(torch.float16, True),
# Todo: cannot configure BF16 when bias is disabled (ORT issue?)
(torch.bfloat16, False),
# Todo: cannot configure BF16 when bias is enabled (ORT issue?)
(torch.bfloat16, True),
],
)
def test_export_linear(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
use_bias: bool,
return_bias: bool,
precision: torch.dtype,
def _test_export_linear(
fp8_recipe: recipe.Recipe = fp8_recipes[0],
use_bias: bool = True,
return_bias: bool = False,
precision: torch.dtype = torch.float32,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if return_bias and not use_bias:
pytest.skip("Cannot return bias when bias is disabled")
......@@ -498,32 +464,28 @@ def test_export_linear(
)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize(
"precision",
[
torch.float32,
torch.float16,
torch.bfloat16,
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
precision: torch.dtype,
zero_centered_gamma: bool,
normalization: str,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision):
_test_export_linear(fp8_recipe=fp8_recipe, precision=precision)
@pytest.mark.parametrize("use_bias", [True, False])
def test_export_linear_use_bias(seed_default_rng, use_bias):
_test_export_linear(use_bias=use_bias)
@pytest.mark.parametrize("return_bias", [True, False])
def test_export_linear_return_bias(seed_default_rng, return_bias):
_test_export_linear(return_bias=return_bias)
def _test_export_layernorm(
fp8_recipe: recipe.Recipe = fp8_recipes[0],
precision: torch.dtype = torch.float32,
zero_centered_gamma: bool = False,
normalization: str = all_normalizations[0],
):
# Set dimensions (these are arbitrary).
batch_size = 4
in_features = 64
......@@ -564,39 +526,31 @@ def test_export_layernorm(
)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize(
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_recipe(seed_default_rng, fp8_recipe, precision):
_test_export_layernorm(fp8_recipe=fp8_recipe, precision=precision)
def test_export_layernorm_zero_centered_gamma(seed_default_rng):
_test_export_layernorm(zero_centered_gamma=True)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_linear(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
use_bias: bool,
return_bias: bool,
return_layernorm_output: bool,
precision: torch.dtype,
zero_centered_gamma: bool,
normalization: str,
def test_export_layernorm_normalization(seed_default_rng, normalization):
_test_export_layernorm(normalization=normalization)
def _test_export_layernorm_linear(
scale_factor: float = 112,
fp8_recipe: recipe.Recipe = fp8_recipes[0],
use_bias: bool = True,
return_bias: bool = False,
return_layernorm_output: bool = False,
precision: torch.dtype = torch.float32,
zero_centered_gamma: bool = False,
normalization: str = all_normalizations[0],
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if return_bias and not use_bias:
pytest.skip("Cannot return bias when bias is disabled")
......@@ -644,41 +598,44 @@ def test_export_layernorm_linear(
)
@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("return_bias", [True, False])
@pytest.mark.parametrize("return_layernorm_output", [True, False])
@pytest.mark.parametrize(
"precision, use_bias",
[
(torch.float32, False),
(torch.float32, True),
(torch.float16, True),
(torch.float16, False),
(torch.bfloat16, True),
(torch.bfloat16, False),
],
)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_mlp(
seed_default_rng,
scale_factor: float,
fp8_recipe: recipe.Recipe,
use_bias: bool,
return_bias: bool,
return_layernorm_output: bool,
precision: torch.dtype,
zero_centered_gamma: bool,
activation: str,
normalization: str,
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision):
_test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision)
def test_export_layernorm_linear_return_ln_out(seed_default_rng):
_test_export_layernorm_linear(return_layernorm_output=True)
def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng):
_test_export_layernorm_linear(zero_centered_gamma=True)
@pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_linear_normalization(seed_default_rng, normalization):
_test_export_layernorm_linear(normalization=normalization)
def test_export_layernorm_linear_no_bias(seed_default_rng):
_test_export_layernorm_linear(use_bias=False)
def test_export_layernorm_linear_return_bias(seed_default_rng):
_test_export_layernorm_linear(return_bias=True)
def _test_export_layernorm_mlp(
scale_factor: float = 112,
fp8_recipe: recipe.Recipe = fp8_recipes[0],
use_bias: bool = True,
return_bias: bool = False,
return_layernorm_output: bool = False,
precision: torch.dtype = torch.float32,
zero_centered_gamma: bool = False,
activation: str = supported_activations[0],
normalization: str = all_normalizations[0],
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if return_bias and not use_bias:
pytest.skip("Cannot return bias when bias is disabled")
......@@ -720,6 +677,38 @@ def test_export_layernorm_mlp(
)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision):
_test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision)
def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng):
_test_export_layernorm_mlp(return_layernorm_output=True)
def test_export_layernorm_mlp_return_bias(seed_default_rng):
_test_export_layernorm_mlp(return_bias=True)
def test_export_layernorm_mlp_no_bias(seed_default_rng):
_test_export_layernorm_mlp(use_bias=False)
def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng):
_test_export_layernorm_mlp(zero_centered_gamma=True)
@pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_mlp_normalization(seed_default_rng, normalization):
_test_export_layernorm_mlp(normalization=normalization)
@pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_layernorm_mlp_activation(seed_default_rng, activation):
_test_export_layernorm_mlp(activation=activation)
@pytest.mark.parametrize(
"precision, use_mask, attn_mask_type",
[
......@@ -734,8 +723,6 @@ def test_export_layernorm_mlp(
],
)
def test_export_core_attention(
seed_default_rng,
set_max_seq_len,
precision: torch.dtype,
use_mask: bool,
attn_mask_type: str,
......@@ -777,11 +764,6 @@ def test_export_core_attention(
)
test_configs_multihead_attention = [
# "use_mask, attn_mask_type"
(False, "no_mask"), # calls ScaledSoftmax
(True, "arbitrary"), # calls ScaledMaskedSoftmax
]
test_configs_attention_type = [
# "input_layernorm, attention_type, fuse_qkv_params"
(True, "self", True),
......@@ -795,31 +777,14 @@ test_configs_attention_type = [
]
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
"input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type
)
def test_export_multihead_attention(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe,
use_mask: bool,
attn_mask_type: str,
precision: torch.dtype,
return_layernorm_output: bool,
input_layernorm: bool,
attention_type: str,
fuse_qkv_params: bool,
def _test_export_multihead_attention(
fp8_recipe: recipe.Recipe = fp8_recipes[0],
use_mask: bool = True,
precision: torch.dtype = torch.float32,
input_layernorm: bool = True,
attention_type: str = "self",
fuse_qkv_params: bool = True,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
hidden_size = 256
sequence_length = 128
batch_size = 4
......@@ -837,6 +802,7 @@ def test_export_multihead_attention(
init_method,
output_layer_init_method,
)
attn_mask_type = "arbitrary" if use_mask else "no_mask"
hidden_states_context = torch.randn(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
......@@ -868,7 +834,7 @@ def test_export_multihead_attention(
*attention_args,
attn_mask_type=attn_mask_type,
params_dtype=precision,
return_layernorm_output=return_layernorm_output,
return_layernorm_output=False,
input_layernorm=input_layernorm,
attention_type=attention_type,
fuse_qkv_params=fuse_qkv_params,
......@@ -960,30 +926,37 @@ def test_export_multihead_attention(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
def test_export_transformer_layer(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe,
use_mask: bool,
attn_mask_type: str,
output_layernorm: bool,
precision: torch.dtype,
fuse_qkv_params: bool,
zero_centered_gamma: bool,
activation: str,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
def test_export_multihead_attention_recipe(fp8_recipe, precision):
_test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision)
def test_export_multihead_attention_no_mask():
_test_export_multihead_attention(use_mask=False)
def test_export_multihead_attention_no_input_layernorm():
_test_export_multihead_attention(input_layernorm=False)
def test_export_multihead_attention_cross_attn():
_test_export_multihead_attention(attention_type="cross")
def test_export_multihead_attention_unfused_qkv_params():
_test_export_multihead_attention(fuse_qkv_params=False)
def _test_export_transformer_layer(
fp8_recipe: recipe.Recipe = fp8_recipes[0],
use_mask: bool = True,
attn_mask_type: str = "arbitrary",
output_layernorm: bool = False,
precision: torch.dtype = torch.float32,
fuse_qkv_params: bool = True,
zero_centered_gamma: bool = False,
activation: str = supported_activations[0],
):
# Layer configuration
hidden_size = 64
sequence_length = 128
......@@ -1043,28 +1016,43 @@ def test_export_transformer_layer(
)
@skip_FP8
@skip_MXFP8
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_transformer_layer_recipe(fp8_recipe, precision):
_test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision)
def test_export_transformer_layer_no_mask():
_test_export_transformer_layer(use_mask=False)
def test_export_transformer_layer_output_layernorm():
_test_export_transformer_layer(output_layernorm=True)
def test_export_transformer_layer_unfused_qkv_params():
_test_export_transformer_layer(fuse_qkv_params=False)
def test_export_transformer_layer_zero_centered_gamma():
_test_export_transformer_layer(zero_centered_gamma=True)
@pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_transformer_layer_activation(activation):
_test_export_transformer_layer(activation=activation)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [True])
def test_export_gpt_generation(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe,
precision: torch.dtype,
zero_centered_gamma: bool,
):
"""Test that the ONNX model can correctly handle inputs with different shapes and that
the attention mask is adjusted on-the-fly to different sequence lengths.
"""
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Layer configuration
hidden_size = 64
sequence_length = 128
......@@ -1091,7 +1079,6 @@ def test_export_gpt_generation(
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma,
).to(device="cuda")
# "Context phase": use full input sequence length
......
......@@ -3,7 +3,6 @@
# See LICENSE for license information.
import random
import pytest
import torch
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
......
......@@ -8,10 +8,10 @@ import pytest
import torch
@pytest.mark.parametrize("use_qk_norm", [False, True])
@pytest.mark.parametrize("qk_norm_type", [None, "L2Normalization", "RMSNorm", "LayerNorm"])
@pytest.mark.parametrize("attention_type", ["self", "cross"])
@pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5])
def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None:
def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> None:
"""Test QK normalization functionality, module structure, and numerical behavior."""
hidden_size = 256
num_attention_heads = 8
......@@ -22,25 +22,59 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_type=attention_type,
use_qk_norm=use_qk_norm,
qk_norm_type=qk_norm_type,
qk_norm_eps=qk_norm_eps,
bias=False,
device="cuda",
).cuda()
# Check module structure based on use_qk_norm parameter
if use_qk_norm:
assert hasattr(mha, "qk_norm"), "Should have qk_norm module when use_qk_norm=True"
assert not hasattr(mha, "q_l2norm"), "Should not have separate q_l2norm module"
assert not hasattr(mha, "k_l2norm"), "Should not have separate k_l2norm module"
# Check that the module is L2Norm type
from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
assert isinstance(
mha.qk_norm, L2Normalization
), "qk_norm should be an L2Normalization module"
# Check module structure based on qk_norm_type parameter
if qk_norm_type is not None:
assert mha.q_norm is not None, "Should have q_norm module when qk_norm_type is not None"
assert mha.k_norm is not None, "Should have k_norm module when qk_norm_type is not None"
# Check that the modules are of the correct type
if qk_norm_type == "L2Normalization":
from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
assert isinstance(
mha.q_norm, L2Normalization
), "q_norm should be an L2Normalization module"
assert isinstance(
mha.k_norm, L2Normalization
), "k_norm should be an L2Normalization module"
# For L2 normalization, q_norm and k_norm should be the same instance (parameter-free)
assert (
mha.q_norm is mha.k_norm
), "q_norm and k_norm should be the same instance for L2 normalization"
elif qk_norm_type == "RMSNorm":
from transformer_engine.pytorch.module.rmsnorm import RMSNorm
assert isinstance(mha.q_norm, RMSNorm), "q_norm should be an RMSNorm module"
assert isinstance(mha.k_norm, RMSNorm), "k_norm should be an RMSNorm module"
# For RMS normalization, q_norm and k_norm should be separate instances
assert (
mha.q_norm is not mha.k_norm
), "q_norm and k_norm should be separate instances for RMS normalization"
elif qk_norm_type == "LayerNorm":
from transformer_engine.pytorch.module.layernorm import LayerNorm
assert isinstance(mha.q_norm, LayerNorm), "q_norm should be a LayerNorm module"
assert isinstance(mha.k_norm, LayerNorm), "k_norm should be a LayerNorm module"
# For LayerNorm, q_norm and k_norm should be separate instances
assert (
mha.q_norm is not mha.k_norm
), "q_norm and k_norm should be separate instances for LayerNorm"
else:
# For extensibility - just ensure they exist
assert mha.q_norm is not None, f"q_norm should exist for qk_norm_type={qk_norm_type}"
assert mha.k_norm is not None, f"k_norm should exist for qk_norm_type={qk_norm_type}"
else:
assert not hasattr(mha, "qk_norm"), "Should not have qk_norm module when use_qk_norm=False"
assert mha.q_norm is None, "Should not have q_norm module when qk_norm_type is None"
assert mha.k_norm is None, "Should not have k_norm module when qk_norm_type is None"
# Create input tensors
batch_size = 2 # Use a fixed batch size for testing
......@@ -89,17 +123,14 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None
assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf"
def test_qk_norm_output_difference() -> None:
@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
def test_qk_norm_output_difference(qk_norm_type) -> None:
"""Test that QK normalization actually changes the output compared to no normalization."""
hidden_size = 256
num_attention_heads = 8
seq_len = 128
batch_size = 2
# Use same random seed to ensure identical weight initialization
current_rng_state = torch.get_rng_state()
current_cuda_rng_state = torch.cuda.get_rng_state()
# Reset to a known seed for reproducible initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
......@@ -108,7 +139,7 @@ def test_qk_norm_output_difference() -> None:
mha_with_norm = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
use_qk_norm=True,
qk_norm_type=qk_norm_type,
bias=False,
device="cuda",
).cuda()
......@@ -121,7 +152,7 @@ def test_qk_norm_output_difference() -> None:
mha_no_norm = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
use_qk_norm=False,
qk_norm_type=None,
bias=False,
device="cuda",
).cuda()
......@@ -139,10 +170,11 @@ def test_qk_norm_output_difference() -> None:
# Outputs should be different when QK normalization is enabled
assert not torch.allclose(
output_with_norm, output_no_norm, atol=1e-6
), "QK normalization should change the output, but outputs are identical"
), f"QK normalization ({qk_norm_type}) should change the output, but outputs are identical"
def test_qk_norm_with_fused_qkv() -> None:
@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
def test_qk_norm_with_fused_qkv(qk_norm_type) -> None:
"""Test QK normalization works with fused QKV parameters."""
hidden_size = 256
num_attention_heads = 8
......@@ -152,7 +184,7 @@ def test_qk_norm_with_fused_qkv() -> None:
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
fuse_qkv_params=True,
use_qk_norm=True,
qk_norm_type=qk_norm_type,
bias=False,
device="cuda",
).cuda()
......@@ -173,7 +205,8 @@ def test_qk_norm_with_fused_qkv() -> None:
), f"Output shape mismatch: {output.shape}"
def test_qk_norm_transformer_layer_output_difference() -> None:
@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
def test_qk_norm_transformer_layer_output_difference(qk_norm_type) -> None:
"""Test that QK normalization actually changes TransformerLayer output compared to no normalization."""
from transformer_engine.pytorch import TransformerLayer
......@@ -183,10 +216,6 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
seq_len = 128
batch_size = 2
# Use same random seed to ensure identical weight initialization
current_rng_state = torch.get_rng_state()
current_cuda_rng_state = torch.cuda.get_rng_state()
# Reset to a known seed for reproducible initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
......@@ -196,7 +225,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
use_qk_norm=True,
qk_norm_type=qk_norm_type,
bias=False,
device="cuda",
).cuda()
......@@ -210,7 +239,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
use_qk_norm=False,
qk_norm_type=None,
bias=False,
device="cuda",
).cuda()
......@@ -226,9 +255,10 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
output_no_norm = transformer_no_norm(hidden_states)
# Outputs should be different when QK normalization is enabled
assert not torch.allclose(
output_with_norm, output_no_norm, atol=1e-6
), "QK normalization should change the TransformerLayer output, but outputs are identical"
assert not torch.allclose(output_with_norm, output_no_norm, atol=1e-6), (
f"QK normalization ({qk_norm_type}) should change the TransformerLayer output, but outputs"
" are identical"
)
# Check that outputs have expected shapes and properties
assert output_with_norm.shape == (
......@@ -240,3 +270,120 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf"
assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN"
assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf"
@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
def test_qk_norm_before_after_rope(qk_norm_type) -> None:
"""Test that QK normalization before and after RoPE works without errors."""
hidden_size = 256
num_attention_heads = 8
seq_len = 64
batch_size = 2
# Create model with QK norm after RoPE (default)
mha_after = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
qk_norm_type=qk_norm_type,
qk_norm_before_rope=False,
bias=False,
device="cuda",
).cuda()
# Create model with QK norm before RoPE
mha_before = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
qk_norm_type=qk_norm_type,
qk_norm_before_rope=True,
bias=False,
device="cuda",
).cuda()
hidden_states = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
# Create RoPE embeddings
head_dim = hidden_size // num_attention_heads
rotary_dim = head_dim // 2
rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32)
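# Note (assumption): rotary_dim is head_dim // 2, so this exercises a partial rotary
# embedding; random embedding values are sufficient for this smoke test of the
# norm-before-RoPE vs. norm-after-RoPE code paths.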
with torch.no_grad():
output_after_rope = mha_after(hidden_states, rotary_pos_emb=rotary_pos_emb)
output_before_rope = mha_before(hidden_states, rotary_pos_emb=rotary_pos_emb)
output_after_no_rope = mha_after(hidden_states)
output_before_no_rope = mha_before(hidden_states)
# Check output shapes and properties
expected_shape = (seq_len, batch_size, hidden_size)
for output in [
output_after_rope,
output_before_rope,
output_after_no_rope,
output_before_no_rope,
]:
assert output.shape == expected_shape, f"Output shape mismatch: {output.shape}"
assert not torch.isnan(output).any(), "Output contains NaN"
assert not torch.isinf(output).any(), "Output contains Inf"
assert output_after_rope.shape == output_before_rope.shape, "Outputs should have same shape"
assert mha_after.qk_norm_before_rope is False, "mha_after should have qk_norm_before_rope=False"
assert mha_before.qk_norm_before_rope is True, "mha_before should have qk_norm_before_rope=True"
def test_different_qk_norm_types_produce_different_outputs() -> None:
"""Test that different QK normalization types produce different outputs."""
hidden_size = 256
num_attention_heads = 8
seq_len = 128
batch_size = 2
# Use same random seed to ensure identical weight initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Create model with L2 normalization
mha_l2 = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
qk_norm_type="L2Normalization",
bias=False,
device="cuda",
).cuda()
# Reset to same seed for identical initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Create model with RMS normalization
mha_rms = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
qk_norm_type="RMSNorm",
bias=False,
device="cuda",
).cuda()
# Create input tensors
hidden_states = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
# Compare outputs with identical weights but different QK norm types
with torch.no_grad():
output_l2 = mha_l2(hidden_states)
output_rms = mha_rms(hidden_states)
# Outputs should be different when using different normalization types
assert not torch.allclose(
output_l2, output_rms, atol=1e-6
), "L2 and RMS normalization should produce different outputs, but outputs are identical"
# Check that outputs have expected shapes and properties
assert output_l2.shape == output_rms.shape, "L2 and RMS outputs should have same shape"
assert not torch.isnan(output_l2).any(), "L2 output contains NaN"
assert not torch.isinf(output_l2).any(), "L2 output contains Inf"
assert not torch.isnan(output_rms).any(), "RMS output contains NaN"
assert not torch.isinf(output_rms).any(), "RMS output contains Inf"
......@@ -192,12 +192,6 @@ class TestFP8Recipe:
amax_compute_algo=amax_compute_algo,
)
# Get FP8 meta tensors
with te.fp8_autocast(fp8_recipe=recipe):
x_fp8_meta = op.get_quantizer("forward", 0)
w_fp8_meta = op.get_quantizer("forward", 1)
dy_fp8_meta = op.get_quantizer("backward", 0)
# Perform training steps
x_history = []
w_history = []
......@@ -229,19 +223,30 @@ class TestFP8Recipe:
y = op(x)
y.backward(dy)
def check_amax_history(
fp8_meta: dict,
ref_amax_history: Iterable[float],
) -> None:
"""Check that amax history matches expected values"""
if len(ref_amax_history) > amax_history_len:
ref_amax_history = ref_amax_history[-amax_history_len:]
def check_metas(
test_scale: float,
test_amax_history: torch.Tensor,
ref_amax_history_list: list[float],
stage: str,
):
"""Check that meta tensors match expected values"""
# Compute amax
if len(ref_amax_history_list) > amax_history_len:
ref_amax_history_list = ref_amax_history_list[-(amax_history_len + 1) :]
ref_amax_history = torch.tensor(
ref_amax_history,
ref_amax_history_list,
dtype=torch.float32,
device=device,
)
test_amax_history = fp8_meta.amax_history[:, 0]
if amax_compute_algo == "max":
ref_amax = max(ref_amax_history_list)
elif amax_compute_algo == "most_recent":
ref_amax = ref_amax_history_list[-1]
else:
raise RuntimeError(f"{amax_compute_algo=} is not supported")
# Compare amax history
tols = dict(rtol=0, atol=0)
torch.testing.assert_close(
test_amax_history[-(step + 1) :],
......@@ -249,23 +254,6 @@ class TestFP8Recipe:
**tols,
)
def check_scale(
quantizer: Float8Quantizer,
ref_amax_history: Iterable[float],
stage: str,
):
"""Check that scale and scale reciprocal match expected values"""
# Compute amax
if len(ref_amax_history) > amax_history_len:
ref_amax_history = ref_amax_history[-(amax_history_len + 1) :]
if amax_compute_algo == "max":
ref_amax = max(ref_amax_history)
elif amax_compute_algo == "most_recent":
ref_amax = ref_amax_history[-1]
else:
raise RuntimeError(f"{amax_compute_algo=} is not supported")
# Compute scale
max_val = {
"forward": 448.0,
......@@ -273,16 +261,26 @@ class TestFP8Recipe:
}[stage]
ref_scale = (max_val / ref_amax) / (2**margin)
# Check values in FP8 meta tensors
# Compare scale
torch.testing.assert_close(
quantizer.scale.item(),
test_scale,
ref_scale,
)
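# Worked example (illustrative only, not executed by the test): with the forward
# max_val of 448.0, margin=0, and a reference amax of 2.0, the expected scale is
# (448.0 / 2.0) / 2**0 = 224.0; a margin of 1 would halve it to 112.0.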
# Get scaling factors
x_test_scale = op.get_quantizer("forward", 0).scale.item()
w_test_scale = op.get_quantizer("forward", 1).scale.item()
dy_test_scale = op.get_quantizer("backward", 0).scale.item()
# Get amax histories
x_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 0]
w_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 1]
dy_test_history = op._fp8_metas["backward"][backward_key].amax_history[:, 0]
# Check that results match expected values
check_scale(x_fp8_meta, x_history, "forward")
check_scale(w_fp8_meta, w_history, "forward")
check_scale(dy_fp8_meta, dy_history, "backward")
check_metas(x_test_scale, x_test_history, x_history, "forward")
check_metas(w_test_scale, w_test_history, w_history, "forward")
check_metas(dy_test_scale, dy_test_history, dy_history, "backward")
@pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
@pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
......
......@@ -2,9 +2,7 @@
#
# See LICENSE for license information.
from dataclasses import dataclass
from typing import Optional
from contextlib import nullcontext
import torch
import pytest
......@@ -18,11 +16,9 @@ from transformer_engine.pytorch.fp8 import (
fp8_model_init,
)
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
get_cudnn_version,
)
from transformer_engine.pytorch import (
LayerNormLinear,
......@@ -32,7 +28,6 @@ from transformer_engine.pytorch import (
TransformerLayer,
RMSNorm,
LayerNorm,
get_cpu_offload_context,
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
......@@ -47,21 +42,17 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from utils import dtype_tols
from utils import ModelConfig
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Record initial RNG state from script run.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))
......@@ -79,88 +70,33 @@ if NVTE_TEST_NVINSPECT_ENABLED:
)
def create_meta(scale_factor: float, size: int = 1):
meta = tex.FP8TensorMeta()
meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
return meta
if IS_HIP_EXTENSION:
from functools import cache
@cache
def use_hipblaslt() -> bool:
return (os.getenv("NVTE_USE_HIPBLASLT") is not None
or os.getenv("NVTE_USE_ROCBLAS") is None )
def custom_amax_to_scale(
amax: torch.Tensor,
scale: torch.Tensor,
fp8_max: torch.Tensor,
recipe: recipe.DelayedScaling,
) -> torch.Tensor:
"""Custom func to test recipe."""
sf = fp8_max / amax
sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale)
return sf
def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
"""Custom func to test recipe."""
return torch.min(amax_history, dim=0).values
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
global _cpu_rng_state, _cuda_rng_state
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@dataclass
class ModelConfig:
"""Transformer model configuration"""
num_layers: int
seq_len: int
batch_size: int
hidden_size: int
num_attention_heads: int
kv_channels: Optional[int] = None
def is_fp8_supported(self):
if self.seq_len * self.batch_size % 16:
return False
if self.hidden_size % 16:
return False
return True
def is_fp8_supported(config: ModelConfig):
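# Note: the FP8 paths exercised here require the flattened token count
# (seqlen * batch) and the hidden sizes to be multiples of 16, which is what
# the checks below enforce; configs such as "weird" are therefore skipped
# for FP8 recipes.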
if (
config.max_seqlen_q * config.batch_size % 16
or config.max_seqlen_kv * config.batch_size % 16
):
return False
if config.hidden_size % 16 or config.hidden_size_kv % 16:
return False
return True
model_configs = {
"126m": ModelConfig(12, 2048, 2, 768, 12),
"small": ModelConfig(2, 32, 2, 64, 2),
"weird": ModelConfig(2, 37, 3, 69, 3),
"large": ModelConfig(1, 128, 2, 512, 4, 128),
"126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
"small": ModelConfig(2, 32, 2, 32, num_layers=2),
"weird": ModelConfig(3, 37, 3, 23, num_layers=2),
"large": ModelConfig(2, 128, 4, 128, num_layers=1),
}
fp8_recipes = [
None, # Test non-FP8
recipe.MXFP8BlockScaling(), # Test default
recipe.Float8CurrentScaling(), # Test default
recipe.Float8BlockScaling(), # Test default
recipe.DelayedScaling(), # Test default
recipe.DelayedScaling( # Test most_recent algo
amax_history_len=16,
amax_compute_algo="most_recent",
),
recipe.DelayedScaling( # Test custom amax and scale compute algo
fp8_format=recipe.Format.E4M3,
amax_compute_algo=custom_amax_compute,
scaling_factor_compute_algo=custom_amax_to_scale,
),
]
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_block_scaling_available:
fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(None)
param_types = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
......@@ -184,66 +120,9 @@ def reset_global_fp8_state():
FP8GlobalStateManager.reset()
def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
# Initialize loss function and optimizer.
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
# Placeholders used for capture.
static_input = torch.randn(
config.seq_len,
config.batch_size,
config.hidden_size,
device="cuda",
dtype=dtype,
requires_grad=True,
)
static_target = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
)
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
use_fp8 = fp8_recipe is not None
if skip_wgrad:
_disable_wgrads(block)
# Pre graph capture warmup in a separate stream.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
optimizer.zero_grad(set_to_none=True)
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
out = block(static_input)
loss = loss_fn(out, static_target)
loss.backward()
optimizer.step()
torch.cuda.current_stream().wait_stream(s)
# Capture.
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
static_output = block(static_input)
static_loss = loss_fn(static_output, static_target)
static_loss.backward()
optimizer.step()
# Fills the graph's input memory with new data to compute on
with torch.no_grad():
static_input.copy_(real_input)
static_target.copy_(real_target)
g.replay()
torch.cuda.synchronize()
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=torch.float32,
device="cuda",
requires_grad=True,
......@@ -251,7 +130,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
(1, 1, config.max_seqlen_q, config.max_seqlen_kv),
dtype=torch.bool,
device="cuda",
)
......@@ -278,14 +157,14 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
(1, 1, config.max_seqlen_q, config.max_seqlen_kv),
dtype=torch.bool,
device="cuda",
)
......@@ -316,9 +195,9 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
......@@ -327,16 +206,9 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
if skip_wgrad:
_disable_wgrads(block)
if cpu_offload:
offload_context, sync_function = get_cpu_offload_context(enabled=True)
else:
offload_context = nullcontext()
sync_function = lambda x: x
use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states)
te_out = sync_function(te_out)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
......@@ -344,7 +216,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
......@@ -352,7 +224,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_attn_mask = torch.randint(
2,
(config.batch_size, 1, 1, config.seq_len),
(config.batch_size, 1, 1, config.max_seqlen_q),
dtype=torch.bool,
device="cuda",
)
......@@ -370,21 +242,21 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.randint(
2,
(1, 1, config.seq_len, config.seq_len),
(1, 1, config.max_seqlen_q, config.max_seqlen_kv),
dtype=torch.bool,
device="cuda",
)
enc_dec_attn_mask = torch.randint(
2,
(config.batch_size, 1, 1, config.seq_len),
(config.batch_size, 1, 1, config.max_seqlen_kv),
dtype=torch.bool,
device="cuda",
)
......@@ -412,7 +284,7 @@ def _test_sanity_common(
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=not skip_dgrad,
......@@ -440,7 +312,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
(config.max_seqlen_q, config.batch_size, config.hidden_size),
device="cuda",
requires_grad=True,
)
......@@ -495,13 +367,7 @@ def test_sanity_layernorm_linear(
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -529,13 +395,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -562,16 +422,10 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
pytest.skip("Quantized model parameters are not supported in debug mode.")
config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size
num_tokens = bs * config.seq_len
num_tokens = bs * config.max_seqlen_q
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None
......@@ -607,16 +461,10 @@ def test_sanity_grouped_linear(
ffn_hidden_size = 4 * config.hidden_size
# Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
bs = bs * 16
num_tokens = bs * config.seq_len * (num_gemms - 1)
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None
......@@ -628,7 +476,7 @@ def test_sanity_grouped_linear(
inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
m_splits = [bs * config.seq_len] * num_gemms
m_splits = [bs * config.max_seqlen_q] * num_gemms
if empty_split == "first":
m_splits[0] = 0
elif empty_split == "last":
......@@ -666,13 +514,7 @@ def test_sanity_layernorm_mlp(
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -697,36 +539,24 @@ def test_sanity_layernorm_mlp(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(
dtype,
fp8_recipe,
model,
skip_wgrad,
zero_centered_gamma,
bias,
activation,
normalization,
parallel_attention_mlp,
cpu_offload,
):
if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("CPU offload is not supported in debug mode.")
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -736,7 +566,7 @@ def test_sanity_gpt(
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -745,7 +575,6 @@ def test_sanity_gpt(
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
bias=bias,
activation=activation,
normalization=normalization,
......@@ -753,7 +582,7 @@ def test_sanity_gpt(
parallel_attention_mlp=parallel_attention_mlp,
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
def test_sanity_gpt_126m():
......@@ -770,12 +599,10 @@ def test_sanity_gpt_126m():
fp8_recipe=fp8_recipe,
model="126m",
skip_wgrad=False,
zero_centered_gamma=True,
bias=True,
activation="gelu",
normalization="LayerNorm",
parallel_attention_mlp=False,
cpu_offload=False,
)
......@@ -783,19 +610,14 @@ def test_sanity_gpt_126m():
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -805,7 +627,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -814,7 +636,6 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
params_dtype=dtype,
apply_residual_connection_post_layernorm=True,
output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
self_attn_mask_type="causal",
normalization=normalization,
device="cuda",
......@@ -835,7 +656,6 @@ def test_sanity_bert_126m():
fp8_recipe=fp8_recipe,
model="126m",
skip_wgrad=False,
zero_centered_gamma=False,
normalization="LayerNorm",
)
......@@ -844,19 +664,14 @@ def test_sanity_bert_126m():
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -866,7 +681,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -876,7 +691,6 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="decoder",
zero_centered_gamma=zero_centered_gamma,
normalization=normalization,
device="cuda",
)
......@@ -896,7 +710,6 @@ def test_sanity_T5_126m():
fp8_recipe=fp8_recipe,
model="126m",
skip_wgrad=False,
zero_centered_gamma=False,
normalization="LayerNorm",
)
......@@ -909,13 +722,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -925,7 +732,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -941,18 +748,11 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
def test_sanity_drop_path(dtype, fp8_recipe, model):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -962,7 +762,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -975,7 +775,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
device="cuda",
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
_test_sanity_e2e(block, dtype, config, fp8_recipe, False)
@pytest.mark.parametrize("dtype", param_types)
......@@ -986,13 +786,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -1002,7 +796,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -1015,27 +809,18 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
device="cuda",
)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -1045,7 +830,7 @@ def test_sanity_gradient_accumulation_fusion(
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
......@@ -1054,7 +839,6 @@ def test_sanity_gradient_accumulation_fusion(
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True,
fuse_wgrad_accumulation=True,
device="cuda",
......@@ -1063,56 +847,6 @@ def test_sanity_gradient_accumulation_fusion(
_test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
if IS_HIP_EXTENSION:
if not use_hipblaslt():
pytest.skip("CUDA graph capture not supported with rocBLAS path")
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling():
pytest.skip("cuda graph not supported for float8_block_scaling recipe")
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True,
normalization=normalization,
device="cuda",
)
_test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)
def test_model_multiple_cast():
a = torch.zeros((16, 16), device="cuda")
m = Linear(16, 32)
......@@ -1167,133 +901,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
torch.cuda.synchronize()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
config = model_configs[model]
outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
outputs_checkpoint_v1_6 = _run_attention_extra_state(
dtype, config, mimic_v1_6=True, checkpoint=True
)
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols.update(dict(rtol=2e-2, atol=2e-3))
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
torch.testing.assert_close(
test,
ref,
**tols,
)
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
torch.testing.assert_close(
test,
ref,
**tols,
)
def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
steps = 10
path = "checkpoint.pt"
fp8_enabled = True
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_enabled,
fp8_mha=False,
)
reset_rng_states()
hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
def get_model(dtype, config):
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.0,
attention_dropout=0.0,
fuse_qkv_params=True,
params_dtype=dtype,
device="cuda",
)
return block
block = get_model(dtype, config)
for i in range(steps // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
if checkpoint:
sd = block.state_dict()
if mimic_v1_6:
sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
"self_attention.core_attention._extra_state"
]
del sd["self_attention.core_attention._extra_state"]
torch.save(sd, path)
param_grads = []
for p in block.parameters():
if p.requires_grad:
param_grads.append(p.grad.clone())
_cpu_rng_state_new = torch.get_rng_state()
_cuda_rng_state_new = torch.cuda.get_rng_state()
del block
block = get_model(dtype, config)
block.load_state_dict(torch.load(path, weights_only=False))
torch.set_rng_state(_cpu_rng_state_new)
torch.cuda.set_rng_state(_cuda_rng_state_new)
for p in block.parameters():
if p.requires_grad:
p.grad = param_grads.pop(0)
assert not param_grads, "Oops!"
for i in range((steps + 1) // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
torch.cuda.synchronize()
if os.path.exists(path):
os.remove(path)
outputs = [output, hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
"""Test the functionality of replace_raw_data"""
......@@ -1389,6 +996,32 @@ def test_sanity_checkpointing_on_callables():
torch.testing.assert_close(grad_checkpoint, grad_standard)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
"""Test that memory usage is optimized when weights are frozen for MXFP8."""
dim = 1024
linear = Linear(dim, dim, bias=False)
x = torch.randn(dim, dim, requires_grad=True, device="cuda")
# Freeze weights
linear.weight.requires_grad = False
# Forward and backward pass with FP8
with fp8_autocast():
o = linear(x)
g_o = torch.randn_like(o)
max_memory_before_backward = torch.cuda.max_memory_allocated()
o.backward(g_o)
max_memory_after_backward = torch.cuda.max_memory_allocated()
memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
assert memory_diff < 5.5, (
f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
" grad_output should be quantized only columnwise."
)
@pytest.mark.parametrize(
"module_name",
("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
......
......@@ -4,12 +4,24 @@
from __future__ import annotations
import logging
import os
from contextlib import contextmanager
import pytest
import torch
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
get_attention_backend,
AttentionParams,
AttentionLogging,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
......@@ -106,3 +118,178 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
if name == "fp8_block_scaling":
return transformer_engine.common.recipe.Float8BlockScaling()
raise ValueError(f"Unsupported quantization scheme ({name})")
# Cached RNG state
_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
def reset_rng_states() -> None:
"""Revert to deterministic RNG state"""
global _rng_states
if _rng_states is None:
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
_rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state())
else:
cpu_rng_state, cuda_rng_state = _rng_states
torch.set_rng_state(cpu_rng_state)
torch.cuda.set_rng_state(cuda_rng_state)
class ModelConfig:
def __init__(
self,
batch_size: int,
max_seqlen_q: int,
num_heads: int,
head_dim_qk: int,
max_seqlen_kv: int = None,
num_gqa_groups: int = None,
head_dim_v: int = None,
dropout_p: float = 0.0,
attn_mask_type: str = "no_mask",
attn_bias_type: str = "no_bias",
alibi_type: str = "none",
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
eps: float = 1e-5,
):
self.batch_size = batch_size
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv
self.num_heads = num_heads
self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups
self.head_dim_qk = head_dim_qk
self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
if self.head_dim_qk == self.head_dim_v:
self.kv_channels = self.head_dim_qk
else:
self.kv_channels = (self.head_dim_qk, self.head_dim_v)
self.hidden_size = self.num_heads * self.head_dim_qk
self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
self.bias_shape = bias_shape
self.window_size = window_size
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
self.eps = eps
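# Illustrative example (not part of the test suite): the "126m" sanity config,
# ModelConfig(2, 2048, 12, 64, num_layers=12), yields hidden_size = 12 * 64 = 768,
# max_seqlen_kv == max_seqlen_q == 2048, num_gqa_groups == num_heads == 12, and
# attn_type == "self", since the optional kv/gqa arguments keep their defaults.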
@contextmanager
def logging_context(highest_level=logging.WARNING):
previous_level = logging.root.manager.disable
logging.disable(highest_level)
try:
yield
finally:
logging.disable(previous_level)
def get_available_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
qkv_layout: str,
window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
is_training: bool = True,
inference_params: Optional[InferenceParams] = None,
) -> Tuple[List, List]:
"""Check for all available attention backends that support a model configuration"""
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes_shape = [config.num_heads]
if config.bias_shape == "bhss":
alibi_slopes_shape = [config.batch_size, config.num_heads]
core_attention_bias_shape = (
config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
)
core_attention_bias_requires_grad = False
# d=256 is supported by cuDNN 9.0+ for inference but not training
if (
config.attn_bias_type == "post_scale_bias"
and config.head_dim_qk <= 128
and config.head_dim_v <= 128
):
core_attention_bias_requires_grad = True
fused_attn_backends = []
available_backends = None
flash_attention_backend = None
fused_attention_backend = None
def test():
attention_params = AttentionParams(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
)
(
use_flash_attention,
use_fused_attention,
flash_attention_backend,
fused_attention_backend,
use_unfused_attention,
available_backends,
) = get_attention_backend(attention_params)
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["flash_attention_backend"] = flash_attention_backend
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return available_backends, flash_attention_backend, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
if AttentionLogging._is_logging_setup is False:
AttentionLogging.setup_logging()
with logging_context(highest_level=AttentionLogging._log_level):
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, flash_attention_backend, fused_attn_backends
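# Usage sketch (illustrative; the "sbh3d" layout and bf16 dtype below are assumed
# example values, not taken from a specific test):
#
#     config = ModelConfig(2, 512, 16, 64)
#     available_backends, flash_backend, fused_backends = get_available_attention_backends(
#         config, qkv_dtype=torch.bfloat16, qkv_layout="sbh3d"
#     )
#     if not fused_backends:
#         pytest.skip("No fused attention backend supports this configuration")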
......@@ -126,6 +126,7 @@ if(USE_CUDA)
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
activation/gelu.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
......@@ -189,6 +190,7 @@ else()
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
......@@ -347,6 +349,8 @@ if(USE_CUDA)
string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.cu
string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(transpose/rtc/swap_first_dims.cu
string_code_transpose_rtc_swap_first_dims_cu)
make_string_header_from_file(utils.cuh
string_code_utils_cuh)
else()
......@@ -358,6 +362,8 @@ else()
string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.hip
string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(transpose/rtc/swap_first_dims.cu
string_code_transpose_rtc_swap_first_dims_cu)
endif()
......@@ -385,6 +391,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
set_source_files_properties(activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
util/cast.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
endif()
......
......@@ -246,6 +246,18 @@ def _load_cudnn():
if found:
return handle
# Attempt to locate libcudnn via ldconfig
libs = subprocess.check_output(
f"ldconfig -p | grep 'libcudnn{_get_sys_extension()}'", shell=True
)
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "libcudnn" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
......@@ -267,12 +279,12 @@ def _load_nvrtc():
return handle
# Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
libs = subprocess.check_output(
f"ldconfig -p | grep 'libnvrtc{_get_sys_extension()}'", shell=True
)
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "stub" in lib or "libnvrtc-builtins" in lib:
continue
if "libnvrtc" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
......
......@@ -189,14 +189,26 @@ CommOverlapCore::~CommOverlapCore() {
if (_atomic_gemm) cudaFree(_counter.dptr());
for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]);
for (size_t i = 0; i < _stream_compute.size(); i++) {
cudaStreamSynchronize(_stream_compute[i]);
cudaStreamDestroy(_stream_compute[i]);
}
auto error = cudaGetLastError();
if (error != cudaSuccess) {
NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error));
}
if (_comm_created) {
try {
#ifdef NVTE_UB_WITH_MPI
destroy_communicator_mpi(_ub_comm);
destroy_communicator_mpi(_ub_comm);
#else
destroy_communicator(_ub_comm);
destroy_communicator(_ub_comm);
#endif
} catch (const std::exception &e) {
NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what());
}
_comm_created = false;
}
}
......@@ -382,6 +394,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
CommOverlapBase::~CommOverlapBase() {
cudaEventDestroy(_start_d2dcopy);
cudaStreamSynchronize(_stream_comm);
cudaStreamDestroy(_stream_comm);
}
......@@ -704,6 +717,25 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
} // CommOverlapBase::split_overlap_rs
void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
cudaStream_t stream_main) {
int comm_bytes = _ubuf.bytes();
int comm_bytes_per_rank = comm_bytes / _tp_size;
// We use the reference to the overlap_gemm to get the stream to send an receive on to ensure the kernels don't finish until the previous gemm is flush
userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
send_stream);
userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
recv_stream);
for (auto stream : {send_stream, recv_stream}) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
// We sync with the comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0));
}
}
/***************************************************************************************************
* Comm+GEMM Overlap P2P Base (Ring-Exchange)
**************************************************************************************************/
......
......@@ -2652,6 +2652,30 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
}
}
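// Explanatory note: together these helpers implement a ring-style all-gather over
// the registered buffer. userbuffers_send_all pushes this rank's slice (offset
// bytes_per_slice * tp_rank) to every peer (tp_rank + j) % tp_size, while
// userbuffers_recv_all pulls each peer's slice into its matching offset, so after
// both loops every rank holds the fully gathered buffer.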
void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream) {
for (int j = 1; j < tp_size; j++) {
int i = (tp_rank + j) % tp_size;
int send_offset = srcoffset + bytes_per_slice * tp_rank;
int recv_offset = dstoffset + bytes_per_slice * tp_rank;
userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
stream);
}
}
void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream) {
for (int j = tp_size - 1; j > 0; j--) {
int i = (tp_rank + j) % tp_size;
int send_offset = srcoffset + bytes_per_slice * i;
int recv_offset = dstoffset + bytes_per_slice * i;
userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
stream);
}
}
// producer
static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) {
// Decrement atomic val to signal current output tile finish
......
......@@ -312,4 +312,12 @@ void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inp
void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream);
void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream);
void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream);
#endif // TRANSFORMER_ENGINE_USERBUFFERS_H_
......@@ -98,6 +98,9 @@ void checkCuDriverContext(CUstream stream) {
#ifdef __HIP_PLATFORM_AMD__
return;
#else
// Ensure the thread's "current" CUDA context is set.
cuda_driver::ensure_context_exists();
CUcontext ctx;
const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
switch (driver_status) {
......@@ -167,10 +170,10 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
(offset_elems * type_num_bits) / 8);
NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment),
NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT),
"Tensor data pointer must be 16B aligned");
const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits;
const int TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits;
NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
"-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
......
......@@ -94,7 +94,7 @@ struct SimpleTensor {
nvte_make_shape(this->shape.data(), this->shape.size())};
}
int numel() const {
size_t numel() const {
size_t acc = 1;
for (const auto &dim : shape) {
acc *= dim;
......@@ -737,7 +737,8 @@ constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;
// Alignment requirements for the Tensor Memory Accelerator (TMA)
constexpr int TMA_gmem_alignment = 16; // global memory address alignment
constexpr size_t TMA_GMEM_ALIGNMENT = 16; // global memory address alignment
constexpr size_t TMA_SHMEM_ALIGNMENT = 128; // shared memory address alignment
inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
......
......@@ -183,7 +183,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
!requires_64bit_ragged_offset) {
!requires_64bit_ragged_offset &&
// 9.10.0: known bugs with SDPA FP8
(cudnn_runtime_version != 91000)) {
if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else {
......@@ -239,20 +241,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
(!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
// 9.10: any head_dim + any arch + fprop + paged
// 9.10: any head_dim + any arch + fprop + non_paged + sq > 1
// 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
(!is_training && cudnn_runtime_version >= 91000 &&
// 9.10.2: any head_dim + any arch + fprop + paged
// 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1
// 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
(!is_training && cudnn_runtime_version >= 91002 &&
(layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 ||
(max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
// 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
(head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
cudnn_runtime_version >= 91100)) &&
// 9.11 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
(!(cudnn_runtime_version == 91100 && is_training && sm_arch_ == 90 && head_dim_qk >= 128 &&
head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) &&
head_dim_qk != head_dim_v))) &&
// 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
(!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training &&
sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
!(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) &&
// bias type
((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
(cudnn_runtime_version >= 8906 &&
......@@ -358,7 +360,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
dropout == 0.0)))) &&
// check 64-bit ragged offset support
(supported_ragged_offset_size)) {
(supported_ragged_offset_size) &&
// 9.10.0/9.10.1: known bugs with SDPA F16
(cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
......
......@@ -90,7 +90,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: Reduce to get the sum of aggregated_probs_per_expert
*/
CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id);
__syncwarp();
if (lane_id == 0) {
......@@ -146,7 +146,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: Reduce to get the sum of aggregated_probs_per_expert
*/
CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id);
__syncwarp();
if (lane_id == 0) {
......
......@@ -107,7 +107,8 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
if (score_function == 0) {
if (topk > 1) {
auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, sum, lane_id);
auto sum_logits =
warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id);
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_logits[i] = static_cast<DataType>(static_cast<double>(local_logits[i]) /
(static_cast<double>(sum_logits) + epsilon));
......@@ -231,13 +232,15 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int
*/
// Sigmoid Post-processing bwd when topk > 1
if (topk > 1 && score_function == 0) {
auto sum_fwd_input = warp_reduce_on_shmem(local_act_from_fwd, num_experts, sum, lane_id);
auto sum_fwd_input =
warp_reduce_on_shmem(local_act_from_fwd, num_experts, ReduceFuncType::SUM, lane_id);
// Put the result of output * grad to the comp_buf
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i];
}
__syncwarp();
auto sum_Output_x_Grad = warp_reduce_on_shmem(local_comp_buf, num_experts, sum, lane_id);
auto sum_Output_x_Grad =
warp_reduce_on_shmem(local_comp_buf, num_experts, ReduceFuncType::SUM, lane_id);
// In-place update
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_grad[i] =
......
......@@ -220,7 +220,7 @@ __global__ void fused_topk_with_score_function_forward_kernel(
// score_function == 0 means sigmoid
if (score_function == 0) {
if (topk > 1) {
double sum_scores = warp_reduce_on_shmem(topk_scores, topk, sum, lane_id);
double sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id);
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
topk_scores[i] = static_cast<double>(topk_scores[i]) / (sum_scores + epsilon);
}
......@@ -362,7 +362,7 @@ __global__ void fused_topk_with_score_function_backward_kernel(
/*data ptr = */ local_act_from_fwd,
/*mask ptr = */ local_routing_map,
/*data size = */ num_experts,
/*reduce func = */ sum, lane_id);
/*reduce func = */ ReduceFuncType::SUM, lane_id);
// Put the result of output * grad to the comp_buf
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_comp_buf[i] = (local_routing_map[i] ? static_cast<double>(local_grad[i]) *
......@@ -374,7 +374,7 @@ __global__ void fused_topk_with_score_function_backward_kernel(
/*data ptr = */ local_comp_buf,
/*mask ptr = */ local_routing_map,
/*data size = */ num_experts,
/*reduce func = */ sum, lane_id);
/*reduce func = */ ReduceFuncType::SUM, lane_id);
// In-place update
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
if (local_routing_map[i]) {
......
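The hunks above switch every warp_reduce_on_shmem / masked_warp_reduce_on_shmem call site from passing a raw device function pointer (sum or max) to passing a ReduceFuncType tag; the helper, updated in the next file, picks the combine function and its identity value internally. A minimal host-side sketch of the same dispatch pattern (illustrative only, not TE's device code):
#include <limits>
enum ReduceFuncType { SUM, MAX };
// The tag selects both the combine function and the identity element,
// mirroring the enum-based dispatch the kernels now use.
static double reduce(const double* data, int n, ReduceFuncType type) {
  double acc = (type == SUM) ? 0.0 : -std::numeric_limits<double>::infinity();
  for (int i = 0; i < n; ++i) {
    acc = (type == SUM) ? acc + data[i] : (data[i] > acc ? data[i] : acc);
  }
  return acc;
}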
......@@ -30,14 +30,28 @@ __device__ inline T sum(T a, T b) {
return a + b;
}
enum ReduceFuncType {
SUM,
MAX,
};
template <typename T>
__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_func)(T, T),
__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncType type,
int lane_id) {
T (*reduce_func)(T, T);
double default_val = 0;
if (type == ReduceFuncType::SUM) {
reduce_func = sum;
default_val = 0;
} else if (type == ReduceFuncType::MAX) {
reduce_func = max;
default_val = -std::numeric_limits<double>::infinity();
}
// Some values are handled within the local thread
// Thread 0 is responsible for the 0-th, 32-th, 64-th, 96-th ... entries
// Reduce the values within the local thread
volatile double val =
lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : static_cast<double>(0);
volatile double val = lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : default_val;
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
val = reduce_func(val, data_ptr[i]);
}
......@@ -69,13 +83,22 @@ __device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, i
template <typename T>
__device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size,
T (*reduce_func)(T, T), int lane_id) {
ReduceFuncType type, int lane_id) {
T (*reduce_func)(T, T);
double default_val = 0;
if (type == ReduceFuncType::SUM) {
reduce_func = sum;
default_val = 0;
} else if (type == ReduceFuncType::MAX) {
reduce_func = max;
default_val = -std::numeric_limits<double>::infinity();
}
// Some values are handled within the local thread
// Thread 0 is responsible for the 0-th, 32-th, 64-th, 96-th ... entries
// Reduce the values within the local thread
volatile double val = lane_id < data_size && mask[lane_id]
? static_cast<double>(data_ptr[lane_id])
: static_cast<double>(0);
volatile double val =
lane_id < data_size && mask[lane_id] ? static_cast<double>(data_ptr[lane_id]) : default_val;
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
if (mask[i]) {
val = reduce_func(val, data_ptr[i]);
......@@ -128,7 +151,7 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_
float sum_Output_x_Grad = warp_reduce_on_shmem(
/*data ptr = */ comp_buf,
/*data size = */ data_size,
/*reduce func = */ sum, lane_id);
/*reduce func = */ ReduceFuncType::SUM, lane_id);
// In-place update
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
if (mask) {
......@@ -147,14 +170,16 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_
template <typename DataType>
__device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) {
// 1. compute the max of value
float max_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, max, lane_id));
float max_val =
static_cast<float>(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::MAX, lane_id));
// 2. value -> exp_value
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
scores[i] = static_cast<float>(exp(static_cast<float>(scores[i]) - max_val));
}
__syncwarp();
// 3. compute the sum of exp_value
float sum_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, sum, lane_id));
float sum_val =
static_cast<float>(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::SUM, lane_id));
// 4. update the softmax value
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
scores[i] = static_cast<float>(scores[i]) / sum_val;
......@@ -165,19 +190,29 @@ __device__ inline void apply_softmax_on_float(DataType *scores, int data_size, i
template <typename T>
__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices,
T *topk_scores, int lane_id) {
// Check whether this index was already selected (masked) in an earlier iteration
auto is_masked = [&topk_indices](int k, int index) {
if (k == 0) return false;
for (int i = 0; i < k; i++) {
if (topk_indices[i] == index) return true;
}
return false;
};
// Repeat topk times: find the max value and its index,
// record the index in topk_indices, and exclude it from later iterations.
// After topk iterations, topk_indices holds the top-k indices.
for (int k = 0; k < topk; k++) {
// Find the max value and its index
volatile double val =
(lane_id < data_size) ? static_cast<double>(scores[lane_id]) : static_cast<double>(0);
volatile double val = (lane_id < data_size && !is_masked(k, lane_id))
? static_cast<double>(scores[lane_id])
: -std::numeric_limits<double>::infinity();
volatile int index = (lane_id < data_size) ? lane_id : 0;
// Some values are handled within the local thread
// Thread 0 is responsible for the 0-th, 32-th, 64-th, 96-th ... entries
// Reduce the values within the local thread
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
volatile double cur_val = scores[i];
volatile double cur_val = (is_masked(k, i)) ? -std::numeric_limits<double>::infinity()
: static_cast<double>(scores[i]);
if (cur_val > val) {
val = cur_val;
index = i;
......@@ -200,17 +235,9 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
if (lane_id == 0) {
topk_indices[k] = index;
topk_scores[k] = val;
scores[index] =
static_cast<double>(-1.0) - val; // make the selected experts using val = - 1 - val
}
__syncwarp();
}
// Reset the scores to the original value
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
scores[topk_indices[i]] =
static_cast<double>(-1.0) - static_cast<double>(scores[topk_indices[i]]);
}
}
// Current TE only supports float32/bf16/fp16; float64 probs should be considered in the future
......
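The naive_topk_and_mask change above no longer marks selected entries by overwriting scores (and restoring them afterwards); instead an is_masked lambda skips indices chosen in earlier iterations, leaving the input buffer untouched. A single-threaded reference sketch of the assumed selection semantics:
#include <limits>
// Scalar reference: repeatedly take the argmax over entries whose index
// has not already been selected in a previous iteration.
static void naive_topk_reference(const double* scores, int n, int topk,
                                 int* topk_indices, double* topk_scores) {
  for (int k = 0; k < topk; ++k) {
    double best = -std::numeric_limits<double>::infinity();
    int best_idx = 0;
    for (int i = 0; i < n; ++i) {
      bool taken = false;
      for (int j = 0; j < k; ++j) taken |= (topk_indices[j] == i);
      if (!taken && scores[i] > best) { best = scores[i]; best_idx = i; }
    }
    topk_indices[k] = best_idx;
    topk_scores[k] = best;
  }
}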
......@@ -253,8 +253,9 @@ using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublas
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split,
int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) {
float alpha, float beta, bool use_split_accumulator, int math_sm_count,
int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter,
cudaStream_t stream) {
// Tensor dims in row-major order
const int A0 = inputA->flat_first_dim();
const int A1 = inputA->flat_last_dim();
......@@ -310,13 +311,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
"fp8 Aux output for gemm + gelu fusion not supported!");
}
if (is_fp8_dtype(outputD->data.dtype)) {
NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
NVTE_CHECK(beta == 0.0f, "Accumulation mode not supported with FP8 GEMM output!");
}
float one = 1.0;
float zero = 0.0;
float beta = (accumulate) ? one : zero;
cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();
cublasLtMatmulDesc_t operationDesc = nullptr;
......@@ -601,7 +598,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
// D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
static_cast<const void *>(&one), /* alpha */
static_cast<const void *>(&alpha), /* alpha */
param.A, /* A */
Adesc, param.B, /* B */
Bdesc, static_cast<const void *>(&beta), /* beta */
......@@ -752,8 +749,27 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
#else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
#endif //__HIP_PLATFORM_AMD__
1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false,
nullptr, stream);
#endif //__HIP_PLATFORM_AMD__
}
void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_gemm_scaled);
using namespace transformer_engine;
const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
Tensor *outputD = convertNVTETensor(D);
const Tensor *biasTensor = convertNVTETensor(bias);
Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
Tensor *wspace = convertNVTETensor(workspace);
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
alpha, beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
}
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
......@@ -846,8 +862,8 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer,
inputCounter, stream);
1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split,
n_split, gemm_producer, inputCounter, stream);
#endif //__HIP_PLATFORM_AMD__
}
......
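The new nvte_cublas_gemm_scaled entry point exposes alpha and beta directly (D = alpha * op(A) * op(B) + beta * D), while nvte_cublas_gemm keeps its boolean accumulate flag and now forwards alpha = 1.0f, beta = accumulate ? 1.0f : 0.0f. A call sketch using the signature added above; the tensor handles and stream are assumed to have been created elsewhere:
// Sketch only: A, B, D, bias, pre_gelu_out and workspace are assumed to be
// valid NVTETensor handles; stream is a valid cudaStream_t.
nvte_cublas_gemm_scaled(A, B, D, bias, pre_gelu_out,
                        /*transa=*/true, /*transb=*/false, /*grad=*/false,
                        workspace,
                        /*alpha=*/0.5f, /*beta=*/1.0f,
                        /*use_split_accumulator=*/false,
                        /*math_sm_count=*/0, stream);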