Commit 87e3e56e authored by yuguo

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
@@ -2,7 +2,6 @@
#
# See LICENSE for license information.
-from collections import OrderedDict
import math
import os
from typing import Dict, List, Tuple, Optional
@@ -39,54 +38,39 @@ from transformer_engine.pytorch import (
    Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
+from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
-from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
+from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.common import recipe
import transformer_engine_torch as tex
+from utils import ModelConfig, reset_rng_states, get_available_attention_backends

# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
-)
+mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()

sm_80plus = get_device_compute_capability() >= (8, 0)

seed = 1234
-torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
-# Record initial RNG state from script run.
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
+# Reset RNG states.
+reset_rng_states()

if torch_version() >= (2, 7, 0):
    torch._dynamo.config.recompile_limit = 16
else:
    torch._dynamo.config.cache_size_limit = 16

-class ModelConfig:
-    def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
-        self.hidden_size = hidden_size
-        self.eps = eps
-        self.num_attention_heads = num_attention_heads
-        self.embed = embed
-        self.num_layers = num_layers
-        self.seq_len = seq_len
-
model_configs = {
-    "small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
-    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
+    "small": ModelConfig(1, 128, 8, 16, num_layers=4),
+    "126m": ModelConfig(1, 2048, 12, 64, num_layers=12),
}

model_configs_inference = {
-    # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
-    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
+    "126m": ModelConfig(1, 256, 12, 64, num_layers=12),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
@@ -120,12 +104,27 @@ if NVTE_TEST_NVINSPECT_ENABLED:
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )

-fp8_recipes = [
-    recipe.MXFP8BlockScaling(),
-    recipe.DelayedScaling(),
-    recipe.Float8CurrentScaling(),
-    recipe.Float8BlockScaling(),
-]
+fp8_recipes = []
+if mxfp8_available:
+    fp8_recipes.append(recipe.MXFP8BlockScaling())
+if fp8_block_scaling_available:
+    fp8_recipes.append(recipe.Float8BlockScaling())
+if fp8_available:
+    fp8_recipes.append(recipe.Float8CurrentScaling())
+    fp8_recipes.append(recipe.DelayedScaling())
+
+
+def is_fused_attn_available(
+    config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
+):
+    _, _, fused_attn_backends = get_available_attention_backends(
+        config,
+        qkv_dtype=dtype,
+        qkv_layout=qkv_layout,
+        is_training=is_training,
+    )
+    return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends


def get_causal_attn_mask(sq: int) -> torch.Tensor:
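Note: the new is_fused_attn_available helper centralizes the backend-capability check; a minimal sketch of how a test in this file uses it (illustrative only, the test name is hypothetical and not part of the diff):

# Illustrative usage, mirroring the skips added further down in this diff.
import pytest
import torch

def test_example_requires_fused_attn():
    config = model_configs["126m"]
    if not is_fused_attn_available(config, torch.bfloat16, qkv_layout="sb3hd", is_training=False):
        pytest.skip("No attention backend available.")
    # ... rest of the test body ...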
@@ -177,12 +176,6 @@ def assert_allclose(
        raise AssertionError(msg)

-def reset_rng_states() -> None:
-    """revert back to initial RNG state."""
-    torch.set_rng_state(_cpu_rng_state)
-    torch.cuda.set_rng_state(_cuda_rng_state)

@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    yield
@@ -535,13 +528,13 @@ def _test_e2e_selective_recompute(
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
@@ -550,13 +543,13 @@
    )

    te_inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        te_out = block(
@@ -582,14 +575,8 @@
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
@@ -630,13 +617,13 @@ def _test_e2e_full_recompute(
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
@@ -645,14 +632,14 @@
    )

    te_inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=use_reentrant,
    )
    if use_reentrant:
        te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if recompute:
@@ -698,14 +685,8 @@
def test_gpt_full_activation_recompute(
    dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
):
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
@@ -761,13 +742,13 @@ def _test_e2e_checkpointing_get_model(config, dtype):
    return TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
@@ -779,7 +760,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
    reset_rng_states()

    te_inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
@@ -809,14 +790,14 @@
            if p.requires_grad:
                param_grads.append(p.grad.clone())

-        global _cpu_rng_state, _cuda_rng_state
        _cpu_rng_state = torch.get_rng_state()
        _cuda_rng_state = torch.cuda.get_rng_state()

        del block
        block = _test_e2e_checkpointing_get_model(config, dtype)
        block.load_state_dict(torch.load(path, weights_only=False))
-        reset_rng_states()
+        torch.set_rng_state(_cpu_rng_state)
+        torch.cuda.set_rng_state(_cuda_rng_state)

        for p in block.parameters():
            if p.requires_grad:
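Note: the checkpointing test now restores the exact CPU and CUDA RNG snapshots captured just before the model is reloaded, rather than rewinding to the old module-level state. A minimal, self-contained sketch of that save/restore pattern (illustrative, not part of the diff):

import torch

# Capture CPU and CUDA RNG states at a known point.
cpu_state = torch.get_rng_state()
cuda_state = torch.cuda.get_rng_state()

first = torch.randn(4, device="cuda")  # advances the CUDA RNG

# Restore both snapshots; the next draw repeats the one above bit-for-bit.
torch.set_rng_state(cpu_state)
torch.cuda.set_rng_state(cuda_state)
second = torch.randn(4, device="cuda")

assert torch.equal(first, second)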
@@ -849,6 +830,8 @@
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
    config = model_configs[model]
+    if not is_fused_attn_available(config, dtype):
+        pytest.skip("No attention backend available.")
    outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
    outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
@@ -869,13 +852,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
    reset_rng_states()

    inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
-    inp_attn_mask = get_causal_attn_mask(config.seq_len)
+    inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

    out = block(inp_hidden_states, attention_mask=inp_attn_mask)
    loss = out.sum()
@@ -895,11 +878,13 @@
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
    config = model_configs[model]
+    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+        pytest.skip("No attention backend available.")

    te_gpt = TransformerLayer(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
-        num_attention_heads=config.num_attention_heads,
+        num_attention_heads=config.num_heads,
        layernorm_epsilon=config.eps,
        attention_dropout=0.1,
        hidden_dropout=0.1,
@@ -914,7 +899,7 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
        TorchGPT(
            config.hidden_size,
            config.eps,
-            config.num_attention_heads,
+            config.num_heads,
            parallel_attention_mlp=parallel_attention_mlp,
        )
        .to(dtype=dtype)
@@ -975,13 +960,13 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
    reset_rng_states()

    inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()
-    inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None
+    inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) if mask_type == "causal" else None

    forward_kwargs = {}
    if te:
@@ -1006,10 +991,12 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
    config = model_configs[model]
+    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+        pytest.skip("No attention backend available.")

    te_mha = MultiheadAttention(
        config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        fuse_qkv_params=True,
        params_dtype=dtype,
        qkv_weight_interleaved=False,
@@ -1020,7 +1007,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
    torch_mha = (
        TorchMHA(
            config.hidden_size,
-            config.num_attention_heads,
+            config.num_heads,
        )
        .to(dtype=dtype)
        .cuda()
@@ -1066,7 +1053,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False,
    FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
@@ -1098,11 +1085,12 @@ def _test_dpa_accuracy(block, bs, dtype, config):
    reset_rng_states()

    mask = torch.triu(
-        torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1
+        torch.ones(config.max_seqlen_q, config.max_seqlen_kv, dtype=torch.bool, device="cuda"),
+        diagonal=1,
    )
    query, key, value = [
        torch.randn(
-            (config.seq_len, bs, config.num_attention_heads, config.embed),
+            (config.max_seqlen_q, bs, config.num_heads, config.kv_channels),
            dtype=dtype,
            device="cuda",
            requires_grad=True,
@@ -1131,8 +1119,8 @@ def test_dpa_accuracy(dtype, bs, model):
    te_dpa = (
        DotProductAttention(
-            config.num_attention_heads,
-            config.embed,
+            config.num_heads,
+            config.kv_channels,
            attention_dropout=0.0,  # disable dropout, FU uses rng differently
        )
        .to(dtype=dtype)
@@ -1141,7 +1129,7 @@
    torch_dpa = (
        TorchDotProductAttention(
-            config.embed,
+            config.kv_channels,
            0.0,  # dropout
        )
        .to(dtype=dtype)
@@ -1267,8 +1255,8 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
        te_linear_ref, bs, dtype, config, delay_wgrad_compute=False
    )

-    # Shoule be bit-wise match
-    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
+    # Should be bit-wise match
+    for _, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@@ -1280,17 +1268,12 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
    fuse_wgrad_accumulation = True
    fp8_model_params = False
    fp8 = recipe is not None
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8 and recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")

    config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
@@ -1653,14 +1636,12 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
@pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("bs", [2])
@pytest.mark.parametrize("model", ["small"])
-@pytest.mark.parametrize("activation", all_activations)
-@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
-    dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation
+    dtype, bs, model, bias, fuse_wgrad_accumulation
):
    config = model_configs[model]
@@ -1669,7 +1650,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
        ffn_hidden_size=4 * config.hidden_size,
        eps=config.eps,
        bias=bias,
-        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=True,
@@ -1681,7 +1661,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
        ffn_hidden_size=4 * config.hidden_size,
        eps=config.eps,
        bias=bias,
-        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=False,
@@ -1691,8 +1670,7 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
    # Share params
    with torch.no_grad():
        ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
-        if normalization != "RMSNorm":
-            ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
+        ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
        ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
        ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
        if bias:
@@ -1730,7 +1708,7 @@ def _test_grouped_linear_accuracy(
    FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
@@ -1743,14 +1721,14 @@
        split_size = 16
        if recipe.mxfp8():
            split_size = 128
-        m = config.seq_len // split_size
+        m = config.max_seqlen_q // split_size
        dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
        dist.append(dist[-1])  # Manually add a zero
        m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
        m_splits = m_splits * split_size
-        assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
+        assert m_splits.sum() == config.max_seqlen_q and len(m_splits) == num_gemms
    else:
-        m_splits = torch.tensor([config.seq_len])
+        m_splits = torch.tensor([config.max_seqlen_q])

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, GroupedLinear):
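Note: the split construction above yields per-GEMM token counts that are multiples of split_size, contain one deliberately empty split, and sum to the full sequence length. A standalone sketch with assumed example values (num_gemms=4, split_size=16, 128 tokens), for illustration only:

import torch

num_gemms, split_size, seqlen = 4, 16, 128   # assumed example values
m = seqlen // split_size                     # 8 blocks of 16 tokens
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1])                        # repeat the last cut -> one zero-sized split
m_splits = (torch.tensor(dist + [m]) - torch.tensor([0] + dist)) * split_size

assert m_splits.sum().item() == seqlen and len(m_splits) == num_gemms
# e.g. dist == [3, 5] gives m_splits == [48, 32, 0, 48]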
@@ -1806,17 +1784,11 @@ def test_grouped_linear_accuracy(
    parallel_mode=None,
):
    fp8 = recipe is not None
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8 and recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
+    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
@@ -1908,19 +1880,13 @@ def test_grouped_linear_accuracy_save_original_input(
    parallel_mode=None,
):
    fp8 = recipe is not None
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8 and recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
-    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
+    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")

    config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
@@ -2074,14 +2040,14 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
    FP8GlobalStateManager.reset()

    inp_hidden_states = torch.randn(
-        (config.seq_len * bs, config.hidden_size),
+        (config.max_seqlen_q * bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    inp_hidden_states.retain_grad()

-    m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)
+    m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs)

    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
        if isinstance(block, TorchGroupedLinearWithPadding):
@@ -2124,17 +2090,11 @@ def test_padding_grouped_linear_accuracy(
    fp8_model_params,
    parallel_mode=None,
):
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
@@ -2199,19 +2159,13 @@ def test_padding_grouped_linear_accuracy_save_original_input(
    fp8_model_params,
    parallel_mode=None,
):
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")

    config = model_configs[model]
-    if config.seq_len % 16 != 0 and fp8:
+    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
@@ -2268,9 +2222,11 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
    # Placeholders used for graph capture.
    static_input = torch.randn(
-        config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
+        config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
    )
-    static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype)
+    static_target = torch.randn(
+        config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype
+    )

    real_input = torch.rand_like(static_input)
    real_target = torch.rand_like(static_target)
@@ -2334,7 +2290,7 @@ def test_gpt_cuda_graph(dtype, bs, model):
    block_args = (
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
    )
    block_kwargs = dict(
        layernorm_epsilon=config.eps,
@@ -2342,7 +2298,7 @@ def test_gpt_cuda_graph(dtype, bs, model):
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
@@ -2377,13 +2333,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
    block = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0.1,
        attention_dropout=0.1,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
        params_dtype=dtype,
@@ -2392,13 +2348,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
    )

    te_inp_hidden_states = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )
    te_inp_hidden_states.retain_grad()
-    te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+    te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)

    with fp8_autocast(enabled=True, fp8_recipe=recipe):
        te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
@@ -2418,14 +2374,8 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
-    if not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
    if NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
-    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
-        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
@@ -2461,13 +2411,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
    block_sbhd = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0,
        attention_dropout=0,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
@@ -2482,13 +2432,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
    block_bshd = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0,
        attention_dropout=0,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
@@ -2500,13 +2450,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
    block_thd = TransformerLayer(
        config.hidden_size,
        4 * config.hidden_size,
-        config.num_attention_heads,
+        config.num_heads,
        layernorm_epsilon=config.eps,
        init_method=init_method,
        output_layer_init_method=output_layer_init_method,
        hidden_dropout=0,
        attention_dropout=0,
-        kv_channels=config.embed,
+        kv_channels=config.kv_channels,
        params_dtype=dtype,
        apply_residual_connection_post_layernorm=False,
        output_layernorm=False,
@@ -2521,15 +2471,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
        assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical"

    x_sbhd = torch.randn(
-        (config.seq_len, bs, config.hidden_size),
+        (config.max_seqlen_q, bs, config.hidden_size),
        dtype=dtype,
        device="cuda",
        requires_grad=True,
    )

    x_bshd = x_sbhd.transpose(0, 1).contiguous()
-    x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous()
-    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len
+    x_thd = x_bshd.reshape(bs * config.max_seqlen_q, config.hidden_size).contiguous()
+    x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.max_seqlen_q

    # To make sure forward is also identical (just in case some module decides
    # to act fancy)
@@ -2556,167 +2506,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
        x_thd,
        cu_seqlens_q=x_thd_cumsum,
        cu_seqlens_kv=x_thd_cumsum,
-        max_seqlen_q=config.seq_len,
-        max_seqlen_kv=config.seq_len,
+        max_seqlen_q=config.max_seqlen_q,
+        max_seqlen_kv=config.max_seqlen_kv,
    )

    torch.testing.assert_close(
        y_bshd,
-        y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
+        y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(),
    )
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model_key", model_configs_inference.keys())
@pytest.mark.parametrize("use_RoPE", all_boolean)
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
@pytest.mark.parametrize("is_paged", [False, True])
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
reset_rng_states()
if backend in ["FusedAttention"]:
pytest.skip("Not support FusedAttention")
if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
pytest.skip("FusedAttention and FlashAttention do not support FP32")
if use_RoPE:
pytest.skip("KV cache does not support starting positions for RoPE")
if (
backend == "FusedAttention"
and get_device_compute_capability() == (8, 9)
and get_cudnn_version() < (9, 12, 0)
):
pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12")
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
elif backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
elif backend == "UnfusedAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
config = model_configs_inference[model_key]
S = config.seq_len
B = bs
H = config.num_attention_heads
D = config.hidden_size
head_size = config.embed
layer_number = 1
# Limits the max size of KV-cache
B_max = B
S_max = S
if module == "TransformerLayer":
model = TransformerLayer(
hidden_size=D,
ffn_hidden_size=4 * D,
num_attention_heads=H,
attn_input_format=input_format,
self_attn_mask_type="causal",
enc_dec_attn_mask_type="causal",
layer_number=layer_number,
attention_dropout=0.0,
params_dtype=dtype,
device="cuda",
).eval()
else:
model = (
MultiheadAttention(
hidden_size=D,
num_attention_heads=H,
qkv_format=input_format,
layer_number=layer_number,
attention_dropout=0.0,
attn_mask_type="causal",
params_dtype=dtype,
)
.cuda()
.eval()
)
inference_params = InferenceParams(
max_batch_size=B_max,
max_sequence_length=S_max,
num_heads_kv=H,
head_dim_k=head_size,
dtype=dtype,
is_paged=is_paged,
total_num_pages=int(B_max * S_max / 256),
page_size=256,
)
rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
input = torch.randn((S, B, D), dtype=dtype, device="cuda")
if input_format == "bshd":
input = input.transpose(0, 1).contiguous()
incremental_output = torch.zeros_like(input)
# Generate output for the entire sequence
full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
# Incrementaly generate outputs using KV-cache
step_dict = OrderedDict(zip(list(range(B)), [1] * B))
for i in range(S):
inference_params.pre_step(step_dict)
if input_format == "sbhd":
incremental_input = input[i].view(1, B, D)
else:
incremental_input = input[:, i, :].view(B, 1, D)
seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv = cu_seqlens_q.clone()
mask_type = "padding"
kwargs = {}
if module == "TransformerLayer":
kwargs["self_attn_mask_type"] = mask_type
else:
kwargs["attn_mask_type"] = mask_type
line_output = model(
hidden_states=incremental_input,
inference_params=inference_params,
rotary_pos_emb=rotary_freqs if use_RoPE else None,
**kwargs,
max_seqlen_q=1,
max_seqlen_kv=S,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
) )
if input_format == "sbhd":
incremental_output[i, :, :] = line_output.view(B, D)
else:
incremental_output[:, i, :] = line_output.view(B, D)
if module == "TransformerLayer":
atol = {
torch.float32: 5e-3,
torch.half: 5e-3,
torch.bfloat16: 5e-2,
}
else:
atol = {
torch.float32: 1e-3,
torch.half: 1e-3,
torch.bfloat16: 1e-2,
}
# Check if the fully generated output matches the one generated incrementally
assert_allclose(full_output, incremental_output, atol[dtype])
@pytest.mark.parametrize(
    "shape",
@@ -2815,9 +2613,8 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
        (16, 4096, 128, 512),
    ],
)
-@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("accumulate", [False, True])
-def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
+def test_fp8_grouped_gemm(shape, accumulate):
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
...
@@ -27,7 +27,6 @@ import warnings
import numpy as np
import onnxruntime as ort
import torch
-import random
from torch import nn as nn
from typing import Optional, Union, Tuple, List
from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
@@ -59,14 +58,13 @@ TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-skip_MXFP8 = pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)

-fp8_recipes = [
-    None,
-    recipe.DelayedScaling(),
-    recipe.MXFP8BlockScaling(),
-]
+fp8_recipes = []
+if mxfp8_available:
+    fp8_recipes.append(recipe.MXFP8BlockScaling())
+if fp8_available:
+    fp8_recipes.append(recipe.DelayedScaling())
+fp8_recipes.append(None)

supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
@@ -369,14 +367,6 @@ def validate_result(
    )

-def create_meta(scale_factor: float, size: int = 1):
-    meta = tex.FP8TensorMeta()
-    meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
-    meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
-    meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
-    return meta

def dtype2str(dtype: torch.dtype, fake_bf16_io=False):
    if fake_bf16_io:
        assert dtype == torch.bfloat16
@@ -413,36 +403,12 @@ Test cases begin here.
"""

-@pytest.mark.parametrize("scale_factor", [112])
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-# Returning the bias is a TE fusion optimization we don't care about.
-@pytest.mark.parametrize("return_bias", [True, False])
-@pytest.mark.parametrize(
-    "precision, use_bias",
-    [
-        (torch.float32, False),
-        (torch.float32, True),
-        (torch.float16, False),
-        (torch.float16, True),
-        # Todo: cannot configure BF16 when bias is disabled (ORT issue?)
-        (torch.bfloat16, False),
-        # Todo: cannot configure BF16 when bias is enabled (ORT issue?)
-        (torch.bfloat16, True),
-    ],
-)
-def test_export_linear(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    use_bias: bool,
-    return_bias: bool,
-    precision: torch.dtype,
+def _test_export_linear(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_bias: bool = True,
+    return_bias: bool = False,
+    precision: torch.dtype = torch.float32,
):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
    if return_bias and not use_bias:
        pytest.skip("Cannot return bias when bias is disabled")
@@ -498,32 +464,28 @@ def test_export_linear(
    )

-@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize(
-    "precision",
-    [
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ],
-)
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
-@pytest.mark.parametrize("normalization", all_normalizations)
-def test_export_layernorm(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    precision: torch.dtype,
-    zero_centered_gamma: bool,
-    normalization: str,
-):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision):
+    _test_export_linear(fp8_recipe=fp8_recipe, precision=precision)
+
+
+@pytest.mark.parametrize("use_bias", [True, False])
+def test_export_linear_use_bias(seed_default_rng, use_bias):
+    _test_export_linear(use_bias=use_bias)
+
+
+@pytest.mark.parametrize("return_bias", [True, False])
+def test_export_linear_return_bias(seed_default_rng, return_bias):
+    _test_export_linear(return_bias=return_bias)
+
+
+def _test_export_layernorm(
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    precision: torch.dtype = torch.float32,
+    zero_centered_gamma: bool = False,
+    normalization: str = all_normalizations[0],
+):
    # Set dimensions (these are arbitrary).
    batch_size = 4
    in_features = 64
@@ -564,39 +526,31 @@ def test_export_layernorm(
    )

-@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("return_bias", [True, False])
-@pytest.mark.parametrize("return_layernorm_output", [True, False])
-@pytest.mark.parametrize(
-    "precision, use_bias",
-    [
-        (torch.float32, False),
-        (torch.float32, True),
-        (torch.float16, True),
-        (torch.float16, False),
-        (torch.bfloat16, True),
-        (torch.bfloat16, False),
-    ],
-)
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
-@pytest.mark.parametrize("normalization", all_normalizations)
-def test_export_layernorm_linear(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    use_bias: bool,
-    return_bias: bool,
-    return_layernorm_output: bool,
-    precision: torch.dtype,
-    zero_centered_gamma: bool,
-    normalization: str,
-):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_layernorm_recipe(seed_default_rng, fp8_recipe, precision):
+    _test_export_layernorm(fp8_recipe=fp8_recipe, precision=precision)
+
+
+def test_export_layernorm_zero_centered_gamma(seed_default_rng):
+    _test_export_layernorm(zero_centered_gamma=True)
+
+
+@pytest.mark.parametrize("normalization", all_normalizations)
+def test_export_layernorm_normalization(seed_default_rng, normalization):
+    _test_export_layernorm(normalization=normalization)
+
+
+def _test_export_layernorm_linear(
+    scale_factor: float = 112,
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_bias: bool = True,
+    return_bias: bool = False,
+    return_layernorm_output: bool = False,
+    precision: torch.dtype = torch.float32,
+    zero_centered_gamma: bool = False,
+    normalization: str = all_normalizations[0],
+):
    if return_bias and not use_bias:
        pytest.skip("Cannot return bias when bias is disabled")
@@ -644,41 +598,44 @@ def test_export_layernorm_linear(
    )

-@pytest.mark.parametrize("scale_factor", [112])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-@pytest.mark.parametrize("return_bias", [True, False])
-@pytest.mark.parametrize("return_layernorm_output", [True, False])
-@pytest.mark.parametrize(
-    "precision, use_bias",
-    [
-        (torch.float32, False),
-        (torch.float32, True),
-        (torch.float16, True),
-        (torch.float16, False),
-        (torch.bfloat16, True),
-        (torch.bfloat16, False),
-    ],
-)
-@pytest.mark.parametrize("zero_centered_gamma", [False, True])
-@pytest.mark.parametrize("activation", supported_activations)
-@pytest.mark.parametrize("normalization", all_normalizations)
-def test_export_layernorm_mlp(
-    seed_default_rng,
-    scale_factor: float,
-    fp8_recipe: recipe.Recipe,
-    use_bias: bool,
-    return_bias: bool,
-    return_layernorm_output: bool,
-    precision: torch.dtype,
-    zero_centered_gamma: bool,
-    activation: str,
-    normalization: str,
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision):
+    _test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision)
+
+
+def test_export_layernorm_linear_return_ln_out(seed_default_rng):
+    _test_export_layernorm_linear(return_layernorm_output=True)
+
+
+def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng):
+    _test_export_layernorm_linear(zero_centered_gamma=True)
+
+
+@pytest.mark.parametrize("normalization", all_normalizations[1:])
+def test_export_layernorm_linear_normalization(seed_default_rng, normalization):
+    _test_export_layernorm_linear(normalization=normalization)
+
+
+def test_export_layernorm_linear_no_bias(seed_default_rng):
+    _test_export_layernorm_linear(use_bias=False)
+
+
+def test_export_layernorm_linear_return_bias(seed_default_rng):
+    _test_export_layernorm_linear(return_bias=True)
+
+
+def _test_export_layernorm_mlp(
+    scale_factor: float = 112,
+    fp8_recipe: recipe.Recipe = fp8_recipes[0],
+    use_bias: bool = True,
+    return_bias: bool = False,
+    return_layernorm_output: bool = False,
+    precision: torch.dtype = torch.float32,
+    zero_centered_gamma: bool = False,
+    activation: str = supported_activations[0],
+    normalization: str = all_normalizations[0],
):
-    # Skip FP8 tests on non-hopper devices
-    if fp8_recipe is not None and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-    if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
-        pytest.skip(reason_for_no_mxfp8)
    if return_bias and not use_bias:
        pytest.skip("Cannot return bias when bias is disabled")
@@ -720,6 +677,38 @@ def test_export_layernorm_mlp(
    )

+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
+def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision):
+    _test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision)
+
+
+def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng):
+    _test_export_layernorm_mlp(return_layernorm_output=True)
+
+
+def test_export_layernorm_mlp_return_bias(seed_default_rng):
+    _test_export_layernorm_mlp(return_bias=True)
+
+
+def test_export_layernorm_mlp_no_bias(seed_default_rng):
+    _test_export_layernorm_mlp(use_bias=False)
+
+
+def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng):
+    _test_export_layernorm_mlp(zero_centered_gamma=True)
+
+
+@pytest.mark.parametrize("normalization", all_normalizations[1:])
+def test_export_layernorm_mlp_normalization(seed_default_rng, normalization):
+    _test_export_layernorm_mlp(normalization=normalization)
+
+
+@pytest.mark.parametrize("activation", supported_activations[1:])
+def test_export_layernorm_mlp_activation(seed_default_rng, activation):
+    _test_export_layernorm_mlp(activation=activation)
+

@pytest.mark.parametrize(
    "precision, use_mask, attn_mask_type",
    [
...@@ -734,8 +723,6 @@ def test_export_layernorm_mlp( ...@@ -734,8 +723,6 @@ def test_export_layernorm_mlp(
], ],
) )
def test_export_core_attention( def test_export_core_attention(
seed_default_rng,
set_max_seq_len,
precision: torch.dtype, precision: torch.dtype,
use_mask: bool, use_mask: bool,
attn_mask_type: str, attn_mask_type: str,
...@@ -777,11 +764,6 @@ def test_export_core_attention( ...@@ -777,11 +764,6 @@ def test_export_core_attention(
) )
test_configs_multihead_attention = [
# "use_mask, attn_mask_type"
(False, "no_mask"), # calls ScaledSoftmax
(True, "arbitrary"), # calls ScaledMaskedSoftmax
]
test_configs_attention_type = [ test_configs_attention_type = [
# "input_layernorm, attention_type, fuse_qkv_params" # "input_layernorm, attention_type, fuse_qkv_params"
(True, "self", True), (True, "self", True),
...@@ -795,31 +777,14 @@ test_configs_attention_type = [ ...@@ -795,31 +777,14 @@ test_configs_attention_type = [
] ]
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) def _test_export_multihead_attention(
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) fp8_recipe: recipe.Recipe = fp8_recipes[0],
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) use_mask: bool = True,
@pytest.mark.parametrize("return_layernorm_output", [False]) precision: torch.dtype = torch.float32,
@pytest.mark.parametrize( input_layernorm: bool = True,
"input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type attention_type: str = "self",
) fuse_qkv_params: bool = True,
def test_export_multihead_attention(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe,
use_mask: bool,
attn_mask_type: str,
precision: torch.dtype,
return_layernorm_output: bool,
input_layernorm: bool,
attention_type: str,
fuse_qkv_params: bool,
): ):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
hidden_size = 256 hidden_size = 256
sequence_length = 128 sequence_length = 128
batch_size = 4 batch_size = 4
...@@ -837,6 +802,7 @@ def test_export_multihead_attention( ...@@ -837,6 +802,7 @@ def test_export_multihead_attention(
init_method, init_method,
output_layer_init_method, output_layer_init_method,
) )
attn_mask_type = "arbitrary" if use_mask else "no_mask"
hidden_states_context = torch.randn( hidden_states_context = torch.randn(
sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
...@@ -868,7 +834,7 @@ def test_export_multihead_attention( ...@@ -868,7 +834,7 @@ def test_export_multihead_attention(
*attention_args, *attention_args,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
params_dtype=precision, params_dtype=precision,
return_layernorm_output=return_layernorm_output, return_layernorm_output=False,
input_layernorm=input_layernorm, input_layernorm=input_layernorm,
attention_type=attention_type, attention_type=attention_type,
fuse_qkv_params=fuse_qkv_params, fuse_qkv_params=fuse_qkv_params,
...@@ -960,30 +926,37 @@ def test_export_multihead_attention( ...@@ -960,30 +926,37 @@ def test_export_multihead_attention(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("output_layernorm", [True, False])
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True]) def test_export_multihead_attention_recipe(fp8_recipe, precision):
@pytest.mark.parametrize("zero_centered_gamma", [False, True]) _test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision)
@pytest.mark.parametrize("activation", supported_activations)
def test_export_transformer_layer(
seed_default_rng, def test_export_multihead_attention_no_mask():
set_max_seq_len, _test_export_multihead_attention(use_mask=False)
fp8_recipe: recipe.Recipe,
use_mask: bool,
attn_mask_type: str, def test_export_multihead_attention_no_input_layernorm():
output_layernorm: bool, _test_export_multihead_attention(input_layernorm=False)
precision: torch.dtype,
fuse_qkv_params: bool,
zero_centered_gamma: bool,
activation: str,
):
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
def test_export_multihead_attention_cross_attn():
_test_export_multihead_attention(attention_type="cross")
def test_export_multihead_attention_unfused_qkv_params():
_test_export_multihead_attention(fuse_qkv_params=False)
def _test_export_transformer_layer(
fp8_recipe: recipe.Recipe = fp8_recipes[0],
use_mask: bool = True,
attn_mask_type: str = "arbitrary",
output_layernorm: bool = False,
precision: torch.dtype = torch.float32,
fuse_qkv_params: bool = True,
zero_centered_gamma: bool = False,
activation: str = supported_activations[0],
):
# Layer configuration # Layer configuration
hidden_size = 64 hidden_size = 64
sequence_length = 128 sequence_length = 128
...@@ -1043,28 +1016,43 @@ def test_export_transformer_layer( ...@@ -1043,28 +1016,43 @@ def test_export_transformer_layer(
) )
# removed:
@skip_FP8
@skip_MXFP8
# added:
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_transformer_layer_recipe(fp8_recipe, precision):
    _test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision)
def test_export_transformer_layer_no_mask():
_test_export_transformer_layer(use_mask=False)
def test_export_transformer_layer_output_layernorm():
_test_export_transformer_layer(output_layernorm=True)
def test_export_transformer_layer_unfused_qkv_params():
_test_export_transformer_layer(fuse_qkv_params=False)
def test_export_transformer_layer_zero_centered_gamma():
_test_export_transformer_layer(zero_centered_gamma=True)
@pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_transformer_layer_activation(activation):
_test_export_transformer_layer(activation=activation)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [True])
def test_export_gpt_generation( def test_export_gpt_generation(
seed_default_rng,
set_max_seq_len,
fp8_recipe: recipe.Recipe, fp8_recipe: recipe.Recipe,
precision: torch.dtype, precision: torch.dtype,
zero_centered_gamma: bool,
): ):
"""Test that the ONNX model can correctly handle inputs with different shapes and that """Test that the ONNX model can correctly handle inputs with different shapes and that
the attention mask is adjusted on-the-fly to different sequence lengths. the attention mask is adjusted on-the-fly to different sequence lengths.
""" """
# Skip FP8 tests on non-hopper devices
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Layer configuration # Layer configuration
hidden_size = 64 hidden_size = 64
sequence_length = 128 sequence_length = 128
...@@ -1091,7 +1079,6 @@ def test_export_gpt_generation( ...@@ -1091,7 +1079,6 @@ def test_export_gpt_generation(
output_layernorm=output_layernorm, output_layernorm=output_layernorm,
params_dtype=precision, params_dtype=precision,
fuse_qkv_params=fuse_qkv_params, fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma,
).to(device="cuda") ).to(device="cuda")
# "Context phase": use full input sequence length # "Context phase": use full input sequence length
......
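The generation test above exports the layer once, runs a full-length "context phase" input, and then feeds shorter inputs, so the exported graph has to treat the sequence dimension as dynamic. Below is a minimal standalone sketch of that idea using plain torch.onnx.export; the TinyAttention module, file name, and axis names are illustrative and are not part of this test suite.

import torch


class TinyAttention(torch.nn.Module):
    """Stand-in module; the real test exports a TransformerLayer."""

    def __init__(self, hidden: int = 64):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


model = TinyAttention().eval()
example = torch.randn(128, 4, 64)  # (seq, batch, hidden) for the "context phase"
torch.onnx.export(
    model,
    (example,),
    "tiny_attention.onnx",
    input_names=["hidden_states"],
    output_names=["output"],
    # Marking dim 0 as dynamic lets the same graph accept the shorter
    # "generation phase" inputs without re-exporting.
    dynamic_axes={"hidden_states": {0: "seq_len"}, "output": {0: "seq_len"}},
)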
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
# See LICENSE for license information. # See LICENSE for license information.
import random import random
import pytest
import torch import torch
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
......
...@@ -8,10 +8,10 @@ import pytest ...@@ -8,10 +8,10 @@ import pytest
import torch import torch
@pytest.mark.parametrize("use_qk_norm", [False, True]) @pytest.mark.parametrize("qk_norm_type", [None, "L2Normalization", "RMSNorm", "LayerNorm"])
@pytest.mark.parametrize("attention_type", ["self", "cross"]) @pytest.mark.parametrize("attention_type", ["self", "cross"])
@pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5]) @pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5])
def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None: def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> None:
"""Test QK normalization functionality, module structure, and numerical behavior.""" """Test QK normalization functionality, module structure, and numerical behavior."""
hidden_size = 256 hidden_size = 256
num_attention_heads = 8 num_attention_heads = 8
...@@ -22,25 +22,59 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None ...@@ -22,25 +22,59 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None
hidden_size=hidden_size, hidden_size=hidden_size,
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
attention_type=attention_type, attention_type=attention_type,
use_qk_norm=use_qk_norm, qk_norm_type=qk_norm_type,
qk_norm_eps=qk_norm_eps, qk_norm_eps=qk_norm_eps,
bias=False, bias=False,
device="cuda", device="cuda",
).cuda() ).cuda()
    # removed:
    # Check module structure based on use_qk_norm parameter
    if use_qk_norm:
        assert hasattr(mha, "qk_norm"), "Should have qk_norm module when use_qk_norm=True"
        assert not hasattr(mha, "q_l2norm"), "Should not have separate q_l2norm module"
        assert not hasattr(mha, "k_l2norm"), "Should not have separate k_l2norm module"

        # Check that the module is L2Norm type
        from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization

        assert isinstance(
            mha.qk_norm, L2Normalization
        ), "qk_norm should be an L2Normalization module"
    # added:
    # Check module structure based on qk_norm_type parameter
    if qk_norm_type is not None:
        assert mha.q_norm is not None, "Should have q_norm module when qk_norm_type is not None"
        assert mha.k_norm is not None, "Should have k_norm module when qk_norm_type is not None"

        # Check that the modules are of the correct type
        if qk_norm_type == "L2Normalization":
            from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization

            assert isinstance(
                mha.q_norm, L2Normalization
            ), "q_norm should be an L2Normalization module"
            assert isinstance(
                mha.k_norm, L2Normalization
            ), "k_norm should be an L2Normalization module"
# For L2 normalization, q_norm and k_norm should be the same instance (parameter-free)
assert (
mha.q_norm is mha.k_norm
), "q_norm and k_norm should be the same instance for L2 normalization"
elif qk_norm_type == "RMSNorm":
from transformer_engine.pytorch.module.rmsnorm import RMSNorm
assert isinstance(mha.q_norm, RMSNorm), "q_norm should be an RMSNorm module"
assert isinstance(mha.k_norm, RMSNorm), "k_norm should be an RMSNorm module"
# For RMS normalization, q_norm and k_norm should be separate instances
assert (
mha.q_norm is not mha.k_norm
), "q_norm and k_norm should be separate instances for RMS normalization"
elif qk_norm_type == "LayerNorm":
from transformer_engine.pytorch.module.layernorm import LayerNorm
assert isinstance(mha.q_norm, LayerNorm), "q_norm should be a LayerNorm module"
assert isinstance(mha.k_norm, LayerNorm), "k_norm should be a LayerNorm module"
# For LayerNorm, q_norm and k_norm should be separate instances
assert (
mha.q_norm is not mha.k_norm
), "q_norm and k_norm should be separate instances for LayerNorm"
else:
# For extensibility - just ensure they exist
assert mha.q_norm is not None, f"q_norm should exist for qk_norm_type={qk_norm_type}"
assert mha.k_norm is not None, f"k_norm should exist for qk_norm_type={qk_norm_type}"
else: else:
assert not hasattr(mha, "qk_norm"), "Should not have qk_norm module when use_qk_norm=False" assert mha.q_norm is None, "Should not have q_norm module when qk_norm_type is None"
assert mha.k_norm is None, "Should not have k_norm module when qk_norm_type is None"
# Create input tensors # Create input tensors
batch_size = 2 # Use a fixed batch size for testing batch_size = 2 # Use a fixed batch size for testing
...@@ -89,17 +123,14 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None ...@@ -89,17 +123,14 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None
assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf" assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf"
def test_qk_norm_output_difference() -> None: @pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
def test_qk_norm_output_difference(qk_norm_type) -> None:
"""Test that QK normalization actually changes the output compared to no normalization.""" """Test that QK normalization actually changes the output compared to no normalization."""
hidden_size = 256 hidden_size = 256
num_attention_heads = 8 num_attention_heads = 8
seq_len = 128 seq_len = 128
batch_size = 2 batch_size = 2
# Use same random seed to ensure identical weight initialization
current_rng_state = torch.get_rng_state()
current_cuda_rng_state = torch.cuda.get_rng_state()
# Reset to a known seed for reproducible initialization # Reset to a known seed for reproducible initialization
torch.manual_seed(42) torch.manual_seed(42)
torch.cuda.manual_seed(42) torch.cuda.manual_seed(42)
...@@ -108,7 +139,7 @@ def test_qk_norm_output_difference() -> None: ...@@ -108,7 +139,7 @@ def test_qk_norm_output_difference() -> None:
mha_with_norm = MultiheadAttention( mha_with_norm = MultiheadAttention(
hidden_size=hidden_size, hidden_size=hidden_size,
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
use_qk_norm=True, qk_norm_type=qk_norm_type,
bias=False, bias=False,
device="cuda", device="cuda",
).cuda() ).cuda()
...@@ -121,7 +152,7 @@ def test_qk_norm_output_difference() -> None: ...@@ -121,7 +152,7 @@ def test_qk_norm_output_difference() -> None:
mha_no_norm = MultiheadAttention( mha_no_norm = MultiheadAttention(
hidden_size=hidden_size, hidden_size=hidden_size,
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
use_qk_norm=False, qk_norm_type=None,
bias=False, bias=False,
device="cuda", device="cuda",
).cuda() ).cuda()
...@@ -139,10 +170,11 @@ def test_qk_norm_output_difference() -> None: ...@@ -139,10 +170,11 @@ def test_qk_norm_output_difference() -> None:
# Outputs should be different when QK normalization is enabled # Outputs should be different when QK normalization is enabled
assert not torch.allclose( assert not torch.allclose(
output_with_norm, output_no_norm, atol=1e-6 output_with_norm, output_no_norm, atol=1e-6
), "QK normalization should change the output, but outputs are identical" ), f"QK normalization ({qk_norm_type}) should change the output, but outputs are identical"
def test_qk_norm_with_fused_qkv() -> None: @pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
def test_qk_norm_with_fused_qkv(qk_norm_type) -> None:
"""Test QK normalization works with fused QKV parameters.""" """Test QK normalization works with fused QKV parameters."""
hidden_size = 256 hidden_size = 256
num_attention_heads = 8 num_attention_heads = 8
...@@ -152,7 +184,7 @@ def test_qk_norm_with_fused_qkv() -> None: ...@@ -152,7 +184,7 @@ def test_qk_norm_with_fused_qkv() -> None:
hidden_size=hidden_size, hidden_size=hidden_size,
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
fuse_qkv_params=True, fuse_qkv_params=True,
use_qk_norm=True, qk_norm_type=qk_norm_type,
bias=False, bias=False,
device="cuda", device="cuda",
).cuda() ).cuda()
...@@ -173,7 +205,8 @@ def test_qk_norm_with_fused_qkv() -> None: ...@@ -173,7 +205,8 @@ def test_qk_norm_with_fused_qkv() -> None:
), f"Output shape mismatch: {output.shape}" ), f"Output shape mismatch: {output.shape}"
def test_qk_norm_transformer_layer_output_difference() -> None: @pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
def test_qk_norm_transformer_layer_output_difference(qk_norm_type) -> None:
"""Test that QK normalization actually changes TransformerLayer output compared to no normalization.""" """Test that QK normalization actually changes TransformerLayer output compared to no normalization."""
from transformer_engine.pytorch import TransformerLayer from transformer_engine.pytorch import TransformerLayer
...@@ -183,10 +216,6 @@ def test_qk_norm_transformer_layer_output_difference() -> None: ...@@ -183,10 +216,6 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
seq_len = 128 seq_len = 128
batch_size = 2 batch_size = 2
# Use same random seed to ensure identical weight initialization
current_rng_state = torch.get_rng_state()
current_cuda_rng_state = torch.cuda.get_rng_state()
# Reset to a known seed for reproducible initialization # Reset to a known seed for reproducible initialization
torch.manual_seed(42) torch.manual_seed(42)
torch.cuda.manual_seed(42) torch.cuda.manual_seed(42)
...@@ -196,7 +225,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None: ...@@ -196,7 +225,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
hidden_size=hidden_size, hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size, ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
use_qk_norm=True, qk_norm_type=qk_norm_type,
bias=False, bias=False,
device="cuda", device="cuda",
).cuda() ).cuda()
...@@ -210,7 +239,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None: ...@@ -210,7 +239,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
hidden_size=hidden_size, hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size, ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
use_qk_norm=False, qk_norm_type=None,
bias=False, bias=False,
device="cuda", device="cuda",
).cuda() ).cuda()
...@@ -226,9 +255,10 @@ def test_qk_norm_transformer_layer_output_difference() -> None: ...@@ -226,9 +255,10 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
output_no_norm = transformer_no_norm(hidden_states) output_no_norm = transformer_no_norm(hidden_states)
# Outputs should be different when QK normalization is enabled # Outputs should be different when QK normalization is enabled
    # removed:
    assert not torch.allclose(
        output_with_norm, output_no_norm, atol=1e-6
    ), "QK normalization should change the TransformerLayer output, but outputs are identical"
    # added:
    assert not torch.allclose(output_with_norm, output_no_norm, atol=1e-6), (
        f"QK normalization ({qk_norm_type}) should change the TransformerLayer output, but outputs"
        " are identical"
    )
# Check that outputs have expected shapes and properties # Check that outputs have expected shapes and properties
assert output_with_norm.shape == ( assert output_with_norm.shape == (
...@@ -240,3 +270,120 @@ def test_qk_norm_transformer_layer_output_difference() -> None: ...@@ -240,3 +270,120 @@ def test_qk_norm_transformer_layer_output_difference() -> None:
assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf" assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf"
assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN" assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN"
assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf" assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf"
@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"])
def test_qk_norm_before_after_rope(qk_norm_type) -> None:
"""Test that QK normalization before and after RoPE works without errors."""
hidden_size = 256
num_attention_heads = 8
seq_len = 64
batch_size = 2
# Create model with QK norm after RoPE (default)
mha_after = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
qk_norm_type=qk_norm_type,
qk_norm_before_rope=False,
bias=False,
device="cuda",
).cuda()
# Create model with QK norm before RoPE
mha_before = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
qk_norm_type=qk_norm_type,
qk_norm_before_rope=True,
bias=False,
device="cuda",
).cuda()
hidden_states = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
# Create RoPE embeddings
head_dim = hidden_size // num_attention_heads
rotary_dim = head_dim // 2
rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32)
with torch.no_grad():
output_after_rope = mha_after(hidden_states, rotary_pos_emb=rotary_pos_emb)
output_before_rope = mha_before(hidden_states, rotary_pos_emb=rotary_pos_emb)
output_after_no_rope = mha_after(hidden_states)
output_before_no_rope = mha_before(hidden_states)
# Check output shapes and properties
expected_shape = (seq_len, batch_size, hidden_size)
for output in [
output_after_rope,
output_before_rope,
output_after_no_rope,
output_before_no_rope,
]:
assert output.shape == expected_shape, f"Output shape mismatch: {output.shape}"
assert not torch.isnan(output).any(), "Output contains NaN"
assert not torch.isinf(output).any(), "Output contains Inf"
assert output_after_rope.shape == output_before_rope.shape, "Outputs should have same shape"
assert mha_after.qk_norm_before_rope == False, "mha_after should have qk_norm_before_rope=False"
assert mha_before.qk_norm_before_rope == True, "mha_before should have qk_norm_before_rope=True"
def test_different_qk_norm_types_produce_different_outputs() -> None:
"""Test that different QK normalization types produce different outputs."""
hidden_size = 256
num_attention_heads = 8
seq_len = 128
batch_size = 2
# Use same random seed to ensure identical weight initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Create model with L2 normalization
mha_l2 = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
qk_norm_type="L2Normalization",
bias=False,
device="cuda",
).cuda()
# Reset to same seed for identical initialization
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Create model with RMS normalization
mha_rms = MultiheadAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
qk_norm_type="RMSNorm",
bias=False,
device="cuda",
).cuda()
# Create input tensors
hidden_states = torch.randn(
seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32
)
# Compare outputs with identical weights but different QK norm types
with torch.no_grad():
output_l2 = mha_l2(hidden_states)
output_rms = mha_rms(hidden_states)
# Outputs should be different when using different normalization types
assert not torch.allclose(
output_l2, output_rms, atol=1e-6
), "L2 and RMS normalization should produce different outputs, but outputs are identical"
# Check that outputs have expected shapes and properties
assert output_l2.shape == output_rms.shape, "L2 and RMS outputs should have same shape"
assert not torch.isnan(output_l2).any(), "L2 output contains NaN"
assert not torch.isinf(output_l2).any(), "L2 output contains Inf"
assert not torch.isnan(output_rms).any(), "RMS output contains NaN"
assert not torch.isinf(output_rms).any(), "RMS output contains Inf"
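For reference, the qk_norm_type options exercised above boil down to simple per-head operations on the query and key vectors before the attention scores are formed. The following is a rough standalone sketch of the two cases that need no learned state, assuming a [seq, batch, heads, head_dim] layout; it is illustrative only and is not TransformerEngine's implementation (its RMSNorm and LayerNorm modules also carry learned parameters).

import torch


def l2_normalize_heads(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scale each head vector to (approximately) unit L2 norm.
    return x / x.norm(dim=-1, keepdim=True).clamp_min(eps)


def rms_normalize_heads(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm over head_dim, shown here without the learned weight.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


q = torch.randn(128, 2, 8, 32)  # [seq, batch, heads, head_dim]
assert torch.allclose(
    l2_normalize_heads(q).norm(dim=-1), torch.ones(128, 2, 8), atol=1e-4
)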
...@@ -192,12 +192,6 @@ class TestFP8Recipe: ...@@ -192,12 +192,6 @@ class TestFP8Recipe:
amax_compute_algo=amax_compute_algo, amax_compute_algo=amax_compute_algo,
) )
# Get FP8 meta tensors
with te.fp8_autocast(fp8_recipe=recipe):
x_fp8_meta = op.get_quantizer("forward", 0)
w_fp8_meta = op.get_quantizer("forward", 1)
dy_fp8_meta = op.get_quantizer("backward", 0)
# Perform training steps # Perform training steps
x_history = [] x_history = []
w_history = [] w_history = []
...@@ -229,19 +223,30 @@ class TestFP8Recipe: ...@@ -229,19 +223,30 @@ class TestFP8Recipe:
y = op(x) y = op(x)
y.backward(dy) y.backward(dy)
        # removed:
        def check_amax_history(
            fp8_meta: dict,
            ref_amax_history: Iterable[float],
        ) -> None:
            """Check that amax history matches expected values"""
            if len(ref_amax_history) > amax_history_len:
                ref_amax_history = ref_amax_history[-amax_history_len:]
            ref_amax_history = torch.tensor(
                ref_amax_history,
                dtype=torch.float32,
                device=device,
            )
            test_amax_history = fp8_meta.amax_history[:, 0]
        # added:
        def check_metas(
            test_scale: float,
            test_amax_history: torch.Tensor,
            ref_amax_history_list: list[float],
            stage: str,
        ):
            """Check that meta tensors match expected values"""

            # Compute amax
            if len(ref_amax_history_list) > amax_history_len:
                ref_amax_history_list = ref_amax_history_list[-(amax_history_len + 1) :]
            ref_amax_history = torch.tensor(
                ref_amax_history_list,
                dtype=torch.float32,
                device=device,
            )
            if amax_compute_algo == "max":
                ref_amax = max(ref_amax_history_list)
            elif amax_compute_algo == "most_recent":
                ref_amax = ref_amax_history_list[-1]
            else:
                raise RuntimeError(f"{amax_compute_algo=} is not supported")

            # Compare amax history
tols = dict(rtol=0, atol=0) tols = dict(rtol=0, atol=0)
torch.testing.assert_close( torch.testing.assert_close(
test_amax_history[-(step + 1) :], test_amax_history[-(step + 1) :],
...@@ -249,23 +254,6 @@ class TestFP8Recipe: ...@@ -249,23 +254,6 @@ class TestFP8Recipe:
**tols, **tols,
) )
def check_scale(
quantizer: Float8Quantizer,
ref_amax_history: Iterable[float],
stage: str,
):
"""Check that scale and scale reciprocal match expected values"""
# Compute amax
if len(ref_amax_history) > amax_history_len:
ref_amax_history = ref_amax_history[-(amax_history_len + 1) :]
if amax_compute_algo == "max":
ref_amax = max(ref_amax_history)
elif amax_compute_algo == "most_recent":
ref_amax = ref_amax_history[-1]
else:
raise RuntimeError(f"{amax_compute_algo=} is not supported")
# Compute scale # Compute scale
max_val = { max_val = {
"forward": 448.0, "forward": 448.0,
...@@ -273,16 +261,26 @@ class TestFP8Recipe: ...@@ -273,16 +261,26 @@ class TestFP8Recipe:
}[stage] }[stage]
ref_scale = (max_val / ref_amax) / (2**margin) ref_scale = (max_val / ref_amax) / (2**margin)
# Check values in FP8 meta tensors # Compare scale
torch.testing.assert_close( torch.testing.assert_close(
quantizer.scale.item(), test_scale,
ref_scale, ref_scale,
) )
# Get scaling factors
x_test_scale = op.get_quantizer("forward", 0).scale.item()
w_test_scale = op.get_quantizer("forward", 1).scale.item()
dy_test_scale = op.get_quantizer("backward", 0).scale.item()
# Get amax histories
x_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 0]
w_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 1]
dy_test_history = op._fp8_metas["backward"][backward_key].amax_history[:, 0]
# Check that results match expected values # Check that results match expected values
check_scale(x_fp8_meta, x_history, "forward") check_metas(x_test_scale, x_test_history, x_history, "forward")
check_scale(w_fp8_meta, w_history, "forward") check_metas(w_test_scale, w_test_history, w_history, "forward")
check_scale(dy_fp8_meta, dy_history, "backward") check_metas(dy_test_scale, dy_test_history, dy_history, "backward")
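The reference values checked above follow the delayed-scaling recipe: keep a rolling window of per-tensor amax values, reduce it with either "max" or "most_recent", and derive the scale from the format maximum and the margin. A small standalone sketch of that arithmetic follows; the 448.0 forward (E4M3) maximum and the scale formula mirror the checks above, while the history layout is illustrative.

import torch


def delayed_scaling_scale(
    amax_history: torch.Tensor,  # rolling window of absolute maxima, most recent last
    amax_compute_algo: str = "max",
    fp8_max: float = 448.0,  # forward (E4M3) maximum used in the test above
    margin: int = 0,
) -> float:
    if amax_compute_algo == "max":
        amax = amax_history.max().item()
    elif amax_compute_algo == "most_recent":
        amax = amax_history[-1].item()
    else:
        raise ValueError(f"{amax_compute_algo=} is not supported")
    # Same formula as ref_scale in the test: (max_val / ref_amax) / (2**margin)
    return (fp8_max / amax) / (2**margin)


print(delayed_scaling_scale(torch.tensor([1.0, 2.0, 4.0])))  # 112.0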
@pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"]) @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
@pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"]) @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
......
...@@ -2,9 +2,7 @@ ...@@ -2,9 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
from dataclasses import dataclass
from typing import Optional from typing import Optional
from contextlib import nullcontext
import torch import torch
import pytest import pytest
...@@ -18,11 +16,9 @@ from transformer_engine.pytorch.fp8 import ( ...@@ -18,11 +16,9 @@ from transformer_engine.pytorch.fp8 import (
fp8_model_init, fp8_model_init,
) )
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal, init_method_normal,
scaled_init_method_normal, scaled_init_method_normal,
is_bf16_compatible, is_bf16_compatible,
get_cudnn_version,
) )
from transformer_engine.pytorch import ( from transformer_engine.pytorch import (
LayerNormLinear, LayerNormLinear,
...@@ -32,7 +28,6 @@ from transformer_engine.pytorch import ( ...@@ -32,7 +28,6 @@ from transformer_engine.pytorch import (
TransformerLayer, TransformerLayer,
RMSNorm, RMSNorm,
LayerNorm, LayerNorm,
get_cpu_offload_context,
) )
from transformer_engine.common import recipe from transformer_engine.common import recipe
import transformer_engine_torch as tex import transformer_engine_torch as tex
...@@ -47,21 +42,17 @@ from transformer_engine.pytorch.tensor.float8_tensor import ( ...@@ -47,21 +42,17 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import checkpoint
from utils import dtype_tols from utils import ModelConfig
# Only run FP8 tests on supported devices. # Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Record initial RNG state from script run. # Record initial RNG state from script run.
seed = 1234 seed = 1234
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0")) NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))
...@@ -79,88 +70,33 @@ if NVTE_TEST_NVINSPECT_ENABLED: ...@@ -79,88 +70,33 @@ if NVTE_TEST_NVINSPECT_ENABLED:
) )
# added:
def is_fp8_supported(config: ModelConfig):
    if (
        config.max_seqlen_q * config.batch_size % 16
        or config.max_seqlen_kv * config.batch_size % 16
    ):
        return False
    if config.hidden_size % 16 or config.hidden_size_kv % 16:
        return False
    return True
# removed:
def create_meta(scale_factor: float, size: int = 1):
    meta = tex.FP8TensorMeta()
    meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
    meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
    meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor
    return meta


if IS_HIP_EXTENSION:
    from functools import cache
@cache
def use_hipblaslt() -> bool:
return (os.getenv("NVTE_USE_HIPBLASLT") is not None
or os.getenv("NVTE_USE_ROCBLAS") is None )
def custom_amax_to_scale(
amax: torch.Tensor,
scale: torch.Tensor,
fp8_max: torch.Tensor,
recipe: recipe.DelayedScaling,
) -> torch.Tensor:
"""Custom func to test recipe."""
sf = fp8_max / amax
sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale)
return sf
def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
"""Custom func to test recipe."""
return torch.min(amax_history, dim=0).values
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
global _cpu_rng_state, _cuda_rng_state
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@dataclass
class ModelConfig:
"""Transformer model configuration"""
num_layers: int
seq_len: int
batch_size: int
hidden_size: int
num_attention_heads: int
kv_channels: Optional[int] = None
def is_fp8_supported(self):
if self.seq_len * self.batch_size % 16:
return False
if self.hidden_size % 16:
return False
return True
model_configs = { model_configs = {
"126m": ModelConfig(12, 2048, 2, 768, 12), "126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
"small": ModelConfig(2, 32, 2, 64, 2), "small": ModelConfig(2, 32, 2, 32, num_layers=2),
"weird": ModelConfig(2, 37, 3, 69, 3), "weird": ModelConfig(3, 37, 3, 23, num_layers=2),
"large": ModelConfig(1, 128, 2, 512, 4, 128), "large": ModelConfig(2, 128, 4, 128, num_layers=1),
} }
# added:
fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_block_scaling_available:
    fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
    fp8_recipes.append(recipe.Float8CurrentScaling())
    fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(None)
# removed:
fp8_recipes = [
    None,  # Test non-FP8
    recipe.MXFP8BlockScaling(),  # Test default
    recipe.Float8CurrentScaling(),  # Test default
    recipe.Float8BlockScaling(),  # Test default
    recipe.DelayedScaling(),  # Test default
    recipe.DelayedScaling(  # Test most_recent algo
        amax_history_len=16,
        amax_compute_algo="most_recent",
),
recipe.DelayedScaling( # Test custom amax and scale compute algo
fp8_format=recipe.Format.E4M3,
amax_compute_algo=custom_amax_compute,
scaling_factor_compute_algo=custom_amax_to_scale,
),
]
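With the recipe list now filtered by hardware support at import time, each sanity test simply wraps its forward pass in fp8_autocast and treats a None entry as the non-FP8 baseline. A condensed sketch of that pattern follows (requires a CUDA device); it assumes the transformer_engine.pytorch Linear module, and the sizes are picked as multiples of 16 so the is_fp8_supported-style shape checks would pass.

import torch
from transformer_engine.pytorch import Linear, fp8_autocast


def run_one_recipe(fp8_recipe) -> None:
    # Illustrative layer and input; hidden size and token count are multiples of 16.
    layer = Linear(128, 128, params_dtype=torch.float32).cuda()
    inp = torch.randn(32, 128, device="cuda", requires_grad=True)
    use_fp8 = fp8_recipe is not None
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = layer(inp)
    out.sum().backward()


for r in fp8_recipes:
    run_one_recipe(r)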
param_types = [torch.float32, torch.float16] param_types = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher if is_bf16_compatible(): # bf16 requires sm_80 or higher
...@@ -184,66 +120,9 @@ def reset_global_fp8_state(): ...@@ -184,66 +120,9 @@ def reset_global_fp8_state():
FP8GlobalStateManager.reset() FP8GlobalStateManager.reset()
def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
# Initialize loss function and optimizer.
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
# Placeholders used for capture.
static_input = torch.randn(
config.seq_len,
config.batch_size,
config.hidden_size,
device="cuda",
dtype=dtype,
requires_grad=True,
)
static_target = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
)
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
use_fp8 = fp8_recipe is not None
if skip_wgrad:
_disable_wgrads(block)
# Pre graph capture warmup in a separate stream.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
optimizer.zero_grad(set_to_none=True)
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
out = block(static_input)
loss = loss_fn(out, static_target)
loss.backward()
optimizer.step()
torch.cuda.current_stream().wait_stream(s)
# Capture.
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
static_output = block(static_input)
static_loss = loss_fn(static_output, static_target)
static_loss.backward()
optimizer.step()
# Fills the graph's input memory with new data to compute on
with torch.no_grad():
static_input.copy_(real_input)
static_target.copy_(real_target)
g.replay()
torch.cuda.synchronize()
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size), (config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=torch.float32, dtype=torch.float32,
device="cuda", device="cuda",
requires_grad=True, requires_grad=True,
...@@ -251,7 +130,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): ...@@ -251,7 +130,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states.retain_grad() te_inp_hidden_states.retain_grad()
te_inp_attn_mask = torch.randint( te_inp_attn_mask = torch.randint(
2, 2,
(1, 1, config.seq_len, config.seq_len), (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
dtype=torch.bool, dtype=torch.bool,
device="cuda", device="cuda",
) )
...@@ -278,14 +157,14 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): ...@@ -278,14 +157,14 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size), (config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype, dtype=dtype,
device="cuda", device="cuda",
requires_grad=True, requires_grad=True,
) )
te_inp_attn_mask = torch.randint( te_inp_attn_mask = torch.randint(
2, 2,
(1, 1, config.seq_len, config.seq_len), (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
dtype=torch.bool, dtype=torch.bool,
device="cuda", device="cuda",
) )
...@@ -316,9 +195,9 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci ...@@ -316,9 +195,9 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}." assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size), (config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype, dtype=dtype,
device="cuda", device="cuda",
requires_grad=True, requires_grad=True,
...@@ -327,16 +206,9 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): ...@@ -327,16 +206,9 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
if skip_wgrad: if skip_wgrad:
_disable_wgrads(block) _disable_wgrads(block)
if cpu_offload:
offload_context, sync_function = get_cpu_offload_context(enabled=True)
else:
offload_context = nullcontext()
sync_function = lambda x: x
use_fp8 = fp8_recipe is not None use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context: with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states) te_out = block(te_inp_hidden_states)
te_out = sync_function(te_out)
loss = te_out.sum() loss = te_out.sum()
loss.backward() loss.backward()
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -344,7 +216,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): ...@@ -344,7 +216,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size), (config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype, dtype=dtype,
device="cuda", device="cuda",
requires_grad=True, requires_grad=True,
...@@ -352,7 +224,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): ...@@ -352,7 +224,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_attn_mask = torch.randint( te_inp_attn_mask = torch.randint(
2, 2,
(config.batch_size, 1, 1, config.seq_len), (config.batch_size, 1, 1, config.max_seqlen_q),
dtype=torch.bool, dtype=torch.bool,
device="cuda", device="cuda",
) )
...@@ -370,21 +242,21 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): ...@@ -370,21 +242,21 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn( te_inp_hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size), (config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype, dtype=dtype,
device="cuda", device="cuda",
requires_grad=True, requires_grad=True,
) )
te_inp_attn_mask = torch.randint( te_inp_attn_mask = torch.randint(
2, 2,
(1, 1, config.seq_len, config.seq_len), (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
dtype=torch.bool, dtype=torch.bool,
device="cuda", device="cuda",
) )
enc_dec_attn_mask = torch.randint( enc_dec_attn_mask = torch.randint(
2, 2,
(config.batch_size, 1, 1, config.seq_len), (config.batch_size, 1, 1, config.max_seqlen_kv),
dtype=torch.bool, dtype=torch.bool,
device="cuda", device="cuda",
) )
...@@ -412,7 +284,7 @@ def _test_sanity_common( ...@@ -412,7 +284,7 @@ def _test_sanity_common(
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.") pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn( te_inp = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size), (config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype, dtype=dtype,
device="cuda", device="cuda",
requires_grad=not skip_dgrad, requires_grad=not skip_dgrad,
...@@ -440,7 +312,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad) ...@@ -440,7 +312,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.") pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn( te_inp = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size), (config.max_seqlen_q, config.batch_size, config.hidden_size),
device="cuda", device="cuda",
requires_grad=True, requires_grad=True,
) )
...@@ -495,13 +367,7 @@ def test_sanity_layernorm_linear( ...@@ -495,13 +367,7 @@ def test_sanity_layernorm_linear(
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not is_fp8_supported(config):
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -529,13 +395,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba ...@@ -529,13 +395,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not is_fp8_supported(config):
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -562,16 +422,10 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ ...@@ -562,16 +422,10 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
pytest.skip("Quantized model parameters are not supported in debug mode.") pytest.skip("Quantized model parameters are not supported in debug mode.")
config = model_configs[model] config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size ffn_hidden_size = 4 * config.hidden_size
num_tokens = bs * config.seq_len num_tokens = bs * config.max_seqlen_q
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not is_fp8_supported(config):
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None use_fp8 = fp8_recipe is not None
...@@ -607,16 +461,10 @@ def test_sanity_grouped_linear( ...@@ -607,16 +461,10 @@ def test_sanity_grouped_linear(
ffn_hidden_size = 4 * config.hidden_size ffn_hidden_size = 4 * config.hidden_size
# Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
bs = bs * 16 bs = bs * 16
num_tokens = bs * config.seq_len * (num_gemms - 1) num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not is_fp8_supported(config):
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None use_fp8 = fp8_recipe is not None
...@@ -628,7 +476,7 @@ def test_sanity_grouped_linear( ...@@ -628,7 +476,7 @@ def test_sanity_grouped_linear(
inp_hidden_states = torch.randn( inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
).cuda() ).cuda()
m_splits = [bs * config.seq_len] * num_gemms m_splits = [bs * config.max_seqlen_q] * num_gemms
if empty_split == "first": if empty_split == "first":
m_splits[0] = 0 m_splits[0] = 0
elif empty_split == "last": elif empty_split == "last":
...@@ -666,13 +514,7 @@ def test_sanity_layernorm_mlp( ...@@ -666,13 +514,7 @@ def test_sanity_layernorm_mlp(
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not is_fp8_supported(config):
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -697,36 +539,24 @@ def test_sanity_layernorm_mlp( ...@@ -697,36 +539,24 @@ def test_sanity_layernorm_mlp(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations) @pytest.mark.parametrize("activation", ["gelu", "swiglu"])
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean) @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt( def test_sanity_gpt(
dtype, dtype,
fp8_recipe, fp8_recipe,
model, model,
skip_wgrad, skip_wgrad,
zero_centered_gamma,
bias, bias,
activation, activation,
normalization, normalization,
parallel_attention_mlp, parallel_attention_mlp,
cpu_offload,
): ):
if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("CPU offload is not supported in debug mode.")
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not is_fp8_supported(config):
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -736,7 +566,7 @@ def test_sanity_gpt( ...@@ -736,7 +566,7 @@ def test_sanity_gpt(
block = TransformerLayer( block = TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_heads,
init_method=init_method, init_method=init_method,
output_layer_init_method=output_layer_init_method, output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1, hidden_dropout=0.1,
...@@ -745,7 +575,6 @@ def test_sanity_gpt( ...@@ -745,7 +575,6 @@ def test_sanity_gpt(
params_dtype=dtype, params_dtype=dtype,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
bias=bias, bias=bias,
activation=activation, activation=activation,
normalization=normalization, normalization=normalization,
...@@ -753,7 +582,7 @@ def test_sanity_gpt( ...@@ -753,7 +582,7 @@ def test_sanity_gpt(
parallel_attention_mlp=parallel_attention_mlp, parallel_attention_mlp=parallel_attention_mlp,
) )
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
def test_sanity_gpt_126m(): def test_sanity_gpt_126m():
...@@ -770,12 +599,10 @@ def test_sanity_gpt_126m(): ...@@ -770,12 +599,10 @@ def test_sanity_gpt_126m():
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
model="126m", model="126m",
skip_wgrad=False, skip_wgrad=False,
zero_centered_gamma=True,
bias=True, bias=True,
activation="gelu", activation="gelu",
normalization="LayerNorm", normalization="LayerNorm",
parallel_attention_mlp=False, parallel_attention_mlp=False,
cpu_offload=False,
) )
...@@ -783,19 +610,14 @@ def test_sanity_gpt_126m(): ...@@ -783,19 +610,14 @@ def test_sanity_gpt_126m():
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization): def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: if not is_fp8_supported(config):
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -805,7 +627,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, ...@@ -805,7 +627,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
block = TransformerLayer( block = TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_heads,
init_method=init_method, init_method=init_method,
output_layer_init_method=output_layer_init_method, output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1, hidden_dropout=0.1,
...@@ -814,7 +636,6 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, ...@@ -814,7 +636,6 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
params_dtype=dtype, params_dtype=dtype,
apply_residual_connection_post_layernorm=True, apply_residual_connection_post_layernorm=True,
output_layernorm=True, output_layernorm=True,
zero_centered_gamma=zero_centered_gamma,
self_attn_mask_type="causal", self_attn_mask_type="causal",
normalization=normalization, normalization=normalization,
device="cuda", device="cuda",
...@@ -835,7 +656,6 @@ def test_sanity_bert_126m(): ...@@ -835,7 +656,6 @@ def test_sanity_bert_126m():
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
model="126m", model="126m",
skip_wgrad=False, skip_wgrad=False,
zero_centered_gamma=False,
normalization="LayerNorm", normalization="LayerNorm",
) )
...@@ -844,19 +664,14 @@ def test_sanity_bert_126m(): ...@@ -844,19 +664,14 @@ def test_sanity_bert_126m():
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization): def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: if not is_fp8_supported(config):
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -866,7 +681,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no ...@@ -866,7 +681,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
block = TransformerLayer( block = TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_heads,
init_method=init_method, init_method=init_method,
output_layer_init_method=output_layer_init_method, output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1, hidden_dropout=0.1,
...@@ -876,7 +691,6 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no ...@@ -876,7 +691,6 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
layer_type="decoder", layer_type="decoder",
zero_centered_gamma=zero_centered_gamma,
normalization=normalization, normalization=normalization,
device="cuda", device="cuda",
) )
...@@ -896,7 +710,6 @@ def test_sanity_T5_126m(): ...@@ -896,7 +710,6 @@ def test_sanity_T5_126m():
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
model="126m", model="126m",
skip_wgrad=False, skip_wgrad=False,
zero_centered_gamma=False,
normalization="LayerNorm", normalization="LayerNorm",
) )
...@@ -909,13 +722,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): ...@@ -909,13 +722,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not is_fp8_supported(config):
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -925,7 +732,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): ...@@ -925,7 +732,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
block = TransformerLayer( block = TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_heads,
init_method=init_method, init_method=init_method,
output_layer_init_method=output_layer_init_method, output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1, hidden_dropout=0.1,
...@@ -941,18 +748,11 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): ...@@ -941,18 +748,11 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_drop_path(dtype, fp8_recipe, model):
def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not is_fp8_supported(config):
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -962,7 +762,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): ...@@ -962,7 +762,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
block = TransformerLayer( block = TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_heads,
init_method=init_method, init_method=init_method,
output_layer_init_method=output_layer_init_method, output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1, hidden_dropout=0.1,
...@@ -975,7 +775,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): ...@@ -975,7 +775,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
device="cuda", device="cuda",
) )
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) _test_sanity_e2e(block, dtype, config, fp8_recipe, False)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
...@@ -986,13 +786,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): ...@@ -986,13 +786,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not is_fp8_supported(config):
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -1002,7 +796,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): ...@@ -1002,7 +796,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
block = TransformerLayer( block = TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_heads,
init_method=init_method, init_method=init_method,
output_layer_init_method=output_layer_init_method, output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1, hidden_dropout=0.1,
...@@ -1015,27 +809,18 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): ...@@ -1015,27 +809,18 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
device="cuda", device="cuda",
) )
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad):
def test_sanity_gradient_accumulation_fusion(
dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not fp8_available: if not is_fp8_supported(config):
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -1045,7 +830,7 @@ def test_sanity_gradient_accumulation_fusion( ...@@ -1045,7 +830,7 @@ def test_sanity_gradient_accumulation_fusion(
block = TransformerLayer( block = TransformerLayer(
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_heads,
init_method=init_method, init_method=init_method,
output_layer_init_method=output_layer_init_method, output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1, hidden_dropout=0.1,
...@@ -1054,7 +839,6 @@ def test_sanity_gradient_accumulation_fusion( ...@@ -1054,7 +839,6 @@ def test_sanity_gradient_accumulation_fusion(
params_dtype=dtype, params_dtype=dtype,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True, fuse_qkv_params=True,
fuse_wgrad_accumulation=True, fuse_wgrad_accumulation=True,
device="cuda", device="cuda",
...@@ -1063,56 +847,6 @@ def test_sanity_gradient_accumulation_fusion( ...@@ -1063,56 +847,6 @@ def test_sanity_gradient_accumulation_fusion(
_test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad) _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
if IS_HIP_EXTENSION:
if not use_hipblaslt():
pytest.skip("CUDA graph capture not supported with rocBLAS path")
config = model_configs[model]
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling():
pytest.skip("cuda graph not supported for float8_block_scaling recipe")
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
fuse_qkv_params=True,
normalization=normalization,
device="cuda",
)
_test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)
def test_model_multiple_cast(): def test_model_multiple_cast():
a = torch.zeros((16, 16), device="cuda") a = torch.zeros((16, 16), device="cuda")
m = Linear(16, 32) m = Linear(16, 32)
...@@ -1167,133 +901,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype): ...@@ -1167,133 +901,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
torch.cuda.synchronize() torch.cuda.synchronize()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
config = model_configs[model]
outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
outputs_checkpoint_v1_6 = _run_attention_extra_state(
dtype, config, mimic_v1_6=True, checkpoint=True
)
# Check that results match
tols = dtype_tols(dtype)
if dtype in (torch.float16, torch.bfloat16):
tols.update(dict(rtol=2e-2, atol=2e-3))
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
torch.testing.assert_close(
test,
ref,
**tols,
)
for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
torch.testing.assert_close(
test,
ref,
**tols,
)
def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
steps = 10
path = "checkpoint.pt"
fp8_enabled = True
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_enabled,
fp8_mha=False,
)
reset_rng_states()
hidden_states = torch.randn(
(config.seq_len, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
def get_model(dtype, config):
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.0,
attention_dropout=0.0,
fuse_qkv_params=True,
params_dtype=dtype,
device="cuda",
)
return block
block = get_model(dtype, config)
for i in range(steps // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
if checkpoint:
sd = block.state_dict()
if mimic_v1_6:
sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
"self_attention.core_attention._extra_state"
]
del sd["self_attention.core_attention._extra_state"]
torch.save(sd, path)
param_grads = []
for p in block.parameters():
if p.requires_grad:
param_grads.append(p.grad.clone())
_cpu_rng_state_new = torch.get_rng_state()
_cuda_rng_state_new = torch.cuda.get_rng_state()
del block
block = get_model(dtype, config)
block.load_state_dict(torch.load(path, weights_only=False))
torch.set_rng_state(_cpu_rng_state_new)
torch.cuda.set_rng_state(_cuda_rng_state_new)
for p in block.parameters():
if p.requires_grad:
p.grad = param_grads.pop(0)
assert not param_grads, "Oops!"
for i in range((steps + 1) // 2):
with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
output = block(hidden_states, None)
loss = output.sum()
loss.backward()
torch.cuda.synchronize()
if os.path.exists(path):
os.remove(path)
outputs = [output, hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
return outputs
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor(): def test_replace_raw_data_for_float8tensor():
"""Test the functionality of replace_raw_data""" """Test the functionality of replace_raw_data"""
...@@ -1389,6 +996,32 @@ def test_sanity_checkpointing_on_callables(): ...@@ -1389,6 +996,32 @@ def test_sanity_checkpointing_on_callables():
torch.testing.assert_close(grad_checkpoint, grad_standard) torch.testing.assert_close(grad_checkpoint, grad_standard)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_linear_frozen_weights_memory_default_recipe():
"""Test that memory usage is optimized when weights are frozen for MXFP8."""
dim = 1024
linear = Linear(dim, dim, bias=False)
x = torch.randn(dim, dim, requires_grad=True, device="cuda")
# Freeze weights
linear.weight.requires_grad = False
# Forward and backward pass with FP8
with fp8_autocast():
o = linear(x)
g_o = torch.randn_like(o)
max_memory_before_backward = torch.cuda.max_memory_allocated()
o.backward(g_o)
max_memory_after_backward = torch.cuda.max_memory_allocated()
memory_diff = (max_memory_after_backward - max_memory_before_backward) / 1e6
assert memory_diff < 5.5, (
f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the"
" grad_output should be quantized only columnwise."
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"module_name", "module_name",
("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"), ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"),
......
...@@ -4,12 +4,24 @@ ...@@ -4,12 +4,24 @@
from __future__ import annotations from __future__ import annotations
import logging
import os
from contextlib import contextmanager
import pytest
import torch import torch
import transformer_engine import transformer_engine
import transformer_engine.common.recipe import transformer_engine.common.recipe
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
get_attention_backend,
AttentionParams,
AttentionLogging,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
...@@ -106,3 +118,178 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]: ...@@ -106,3 +118,178 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
if name == "fp8_block_scaling": if name == "fp8_block_scaling":
return transformer_engine.common.recipe.Float8BlockScaling() return transformer_engine.common.recipe.Float8BlockScaling()
raise ValueError(f"Unsupported quantization scheme ({name})") raise ValueError(f"Unsupported quantization scheme ({name})")
# Cached RNG state
_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
def reset_rng_states() -> None:
"""Revert to deterministic RNG state"""
global _rng_states
if _rng_states is None:
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
_rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state())
else:
cpu_rng_state, cuda_rng_state = _rng_states
torch.set_rng_state(cpu_rng_state)
torch.cuda.set_rng_state(cuda_rng_state)
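A minimal usage sketch of the helper above (illustrative only): the first call seeds and caches the CPU/CUDA RNG state, and every later call restores that cached state, so tensors drawn immediately after it are reproducible across tests.

import torch

reset_rng_states()                      # first call: seed and cache CPU/CUDA RNG state
a = torch.randn(4, device="cuda")
reset_rng_states()                      # later calls: restore the cached state
b = torch.randn(4, device="cuda")
assert torch.equal(a, b)                # identical draws after each reset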
class ModelConfig:
def __init__(
self,
batch_size: int,
max_seqlen_q: int,
num_heads: int,
head_dim_qk: int,
max_seqlen_kv: int = None,
num_gqa_groups: int = None,
head_dim_v: int = None,
dropout_p: float = 0.0,
attn_mask_type: str = "no_mask",
attn_bias_type: str = "no_bias",
alibi_type: str = "none",
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
eps: float = 1e-5,
):
self.batch_size = batch_size
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv
self.num_heads = num_heads
self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups
self.head_dim_qk = head_dim_qk
self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
if self.head_dim_qk == self.head_dim_v:
self.kv_channels = self.head_dim_qk
else:
self.kv_channels = (self.head_dim_qk, self.head_dim_v)
self.hidden_size = self.num_heads * self.head_dim_qk
self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
self.bias_shape = bias_shape
self.window_size = window_size
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
self.eps = eps
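A hedged sketch of how the derived attributes fall out of this constructor; the argument values are illustrative, not taken from any particular test.

cfg = ModelConfig(batch_size=2, max_seqlen_q=128, num_heads=8, head_dim_qk=64)
assert cfg.max_seqlen_kv == 128       # defaults to max_seqlen_q
assert cfg.num_gqa_groups == 8        # defaults to num_heads
assert cfg.kv_channels == 64          # single int because head_dim_qk == head_dim_v
assert cfg.hidden_size == 8 * 64      # num_heads * head_dim_qk
assert cfg.hidden_size_kv == 8 * 64   # num_gqa_groups * head_dim_v
assert cfg.attn_type == "self"        # max_seqlen_q == max_seqlen_kv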
@contextmanager
def logging_context(highest_level=logging.WARNING):
previous_level = logging.root.manager.disable
logging.disable(highest_level)
try:
yield
finally:
logging.disable(previous_level)
def get_available_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
qkv_layout: str,
window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
is_training: bool = True,
inference_params: Optional[InferenceParams] = None,
) -> Tuple[List, List]:
"""Check for all available attention backends that support a model configuration"""
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes_shape = [config.num_heads]
if config.bias_shape == "bhss":
alibi_slopes_shape = [config.batch_size, config.num_heads]
core_attention_bias_shape = (
config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
)
core_attention_bias_requires_grad = False
# d=256 is supported by cuDNN 9.0+ for inference but not training
if (
config.attn_bias_type == "post_scale_bias"
and config.head_dim_qk <= 128
and config.head_dim_v <= 128
):
core_attention_bias_requires_grad = True
fused_attn_backends = []
available_backends = None
flash_attention_backend = None
fused_attention_backend = None
def test():
attention_params = AttentionParams(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
)
(
use_flash_attention,
use_fused_attention,
flash_attention_backend,
fused_attention_backend,
use_unfused_attention,
available_backends,
) = get_attention_backend(attention_params)
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["flash_attention_backend"] = flash_attention_backend
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return available_backends, flash_attention_backend, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
if AttentionLogging._is_logging_setup is False:
AttentionLogging.setup_logging()
with logging_context(highest_level=AttentionLogging._log_level):
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, flash_attention_backend, fused_attn_backends
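A hedged usage sketch of the helper above (the config values and the qkv_layout string are illustrative, and the imports already present in this utils module are assumed): a test can probe which fused-attention backends support its configuration and skip when the one it needs is unavailable.

config = ModelConfig(batch_size=2, max_seqlen_q=128, num_heads=8, head_dim_qk=64)
available_backends, flash_backend, fused_backends = get_available_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,
    qkv_layout="sbh3d",          # illustrative layout string
    is_training=True,
)
if FusedAttnBackend["F16_arbitrary_seqlen"] not in fused_backends:
    pytest.skip("Fused attention does not support this model configuration")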
...@@ -126,6 +126,7 @@ if(USE_CUDA) ...@@ -126,6 +126,7 @@ if(USE_CUDA)
transpose/multi_cast_transpose.cu transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
activation/gelu.cu activation/gelu.cu
fused_attn/flash_attn.cu fused_attn/flash_attn.cu
fused_attn/context_parallel.cu fused_attn/context_parallel.cu
...@@ -189,6 +190,7 @@ else() ...@@ -189,6 +190,7 @@ else()
transpose/multi_cast_transpose.cu transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
activation/gelu.cu activation/gelu.cu
activation/relu.cu activation/relu.cu
activation/swiglu.cu activation/swiglu.cu
...@@ -347,6 +349,8 @@ if(USE_CUDA) ...@@ -347,6 +349,8 @@ if(USE_CUDA)
string_code_transpose_rtc_cast_transpose_cu) string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.cu make_string_header_from_file(transpose/rtc/transpose.cu
string_code_transpose_rtc_transpose_cu) string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(transpose/rtc/swap_first_dims.cu
string_code_transpose_rtc_swap_first_dims_cu)
make_string_header_from_file(utils.cuh make_string_header_from_file(utils.cuh
string_code_utils_cuh) string_code_utils_cuh)
else() else()
...@@ -358,6 +362,8 @@ else() ...@@ -358,6 +362,8 @@ else()
string_code_transpose_rtc_cast_transpose_cu) string_code_transpose_rtc_cast_transpose_cu)
make_string_header_from_file(transpose/rtc/transpose.hip make_string_header_from_file(transpose/rtc/transpose.hip
string_code_transpose_rtc_transpose_cu) string_code_transpose_rtc_transpose_cu)
make_string_header_from_file(transpose/rtc/swap_first_dims.cu
string_code_transpose_rtc_swap_first_dims_cu)
endif() endif()
...@@ -385,6 +391,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) ...@@ -385,6 +391,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
set_source_files_properties(activation/gelu.cu set_source_files_properties(activation/gelu.cu
activation/relu.cu activation/relu.cu
activation/swiglu.cu activation/swiglu.cu
util/cast.cu
PROPERTIES PROPERTIES
COMPILE_OPTIONS "--use_fast_math") COMPILE_OPTIONS "--use_fast_math")
endif() endif()
......
...@@ -246,6 +246,18 @@ def _load_cudnn(): ...@@ -246,6 +246,18 @@ def _load_cudnn():
if found: if found:
return handle return handle
# Attempt to locate libcudnn via ldconfig
libs = subprocess.check_output(
f"ldconfig -p | grep 'libcudnn{_get_sys_extension()}'", shell=True
)
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "libcudnn" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
...@@ -267,12 +279,12 @@ def _load_nvrtc(): ...@@ -267,12 +279,12 @@ def _load_nvrtc():
return handle return handle
# Attempt to locate NVRTC via ldconfig # Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True) libs = subprocess.check_output(
f"ldconfig -p | grep 'libnvrtc{_get_sys_extension()}'", shell=True
)
libs = libs.decode("utf-8").split("\n") libs = libs.decode("utf-8").split("\n")
sos = [] sos = []
for lib in libs: for lib in libs:
if "stub" in lib or "libnvrtc-builtins" in lib:
continue
if "libnvrtc" in lib and "=>" in lib: if "libnvrtc" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip()) sos.append(lib.split(">")[1].strip())
if sos: if sos:
......
...@@ -189,14 +189,26 @@ CommOverlapCore::~CommOverlapCore() { ...@@ -189,14 +189,26 @@ CommOverlapCore::~CommOverlapCore() {
if (_atomic_gemm) cudaFree(_counter.dptr()); if (_atomic_gemm) cudaFree(_counter.dptr());
for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); for (size_t i = 0; i < _stream_compute.size(); i++) {
cudaStreamSynchronize(_stream_compute[i]);
cudaStreamDestroy(_stream_compute[i]);
}
auto error = cudaGetLastError();
if (error != cudaSuccess) {
NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error));
}
if (_comm_created) { if (_comm_created) {
try {
#ifdef NVTE_UB_WITH_MPI #ifdef NVTE_UB_WITH_MPI
destroy_communicator_mpi(_ub_comm); destroy_communicator_mpi(_ub_comm);
#else #else
destroy_communicator(_ub_comm); destroy_communicator(_ub_comm);
#endif #endif
} catch (const std::exception &e) {
NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what());
}
_comm_created = false; _comm_created = false;
} }
} }
...@@ -382,6 +394,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType ...@@ -382,6 +394,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
CommOverlapBase::~CommOverlapBase() { CommOverlapBase::~CommOverlapBase() {
cudaEventDestroy(_start_d2dcopy); cudaEventDestroy(_start_d2dcopy);
cudaStreamSynchronize(_stream_comm);
cudaStreamDestroy(_stream_comm); cudaStreamDestroy(_stream_comm);
} }
...@@ -704,6 +717,25 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons ...@@ -704,6 +717,25 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
} // CommOverlapBase::split_overlap_rs } // CommOverlapBase::split_overlap_rs
void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
cudaStream_t stream_main) {
int comm_bytes = _ubuf.bytes();
int comm_bytes_per_rank = comm_bytes / _tp_size;
  // We use the reference to the overlap_gemm to get the streams to send and receive on, to ensure the kernels don't finish until the previous gemm is flushed
userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
send_stream);
userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
recv_stream);
for (auto stream : {send_stream, recv_stream}) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
// We sync with the comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0));
}
}
/*************************************************************************************************** /***************************************************************************************************
* Comm+GEMM Overlap P2P Base (Ring-Exchange) * Comm+GEMM Overlap P2P Base (Ring-Exchange)
**************************************************************************************************/ **************************************************************************************************/
......
...@@ -2652,6 +2652,30 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds ...@@ -2652,6 +2652,30 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
} }
} }
void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream) {
for (int j = 1; j < tp_size; j++) {
int i = (tp_rank + j) % tp_size;
int send_offset = srcoffset + bytes_per_slice * tp_rank;
int recv_offset = dstoffset + bytes_per_slice * tp_rank;
userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
stream);
}
}
void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream) {
for (int j = tp_size - 1; j > 0; j--) {
int i = (tp_rank + j) % tp_size;
int send_offset = srcoffset + bytes_per_slice * i;
int recv_offset = dstoffset + bytes_per_slice * i;
userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
stream);
}
}
// producer // producer
static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) { static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) {
// Decrement atomic val to signal current output tile finish // Decrement atomic val to signal current output tile finish
......
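The two helpers added above implement a simple all-gather style exchange: each rank sends its own slice (offset tp_rank * bytes_per_slice) to every peer in rotated order, and receives each peer's slice into that peer's slot, walking the peers in the reverse rotation so that, as the loops suggest, the k-th receive on each rank pairs with the k-th send on its partner. A plain-Python sketch of the schedule (illustrative only, not the CUDA implementation):

# Illustrative schedule for tp_size = 4
tp_size, bytes_per_slice = 4, 1024
for tp_rank in range(tp_size):
    sends, recvs = [], []
    for j in range(1, tp_size):                         # userbuffers_send_all
        peer = (tp_rank + j) % tp_size
        sends.append((peer, bytes_per_slice * tp_rank))  # always my own slice
    for j in range(tp_size - 1, 0, -1):                  # userbuffers_recv_all
        peer = (tp_rank + j) % tp_size
        recvs.append((peer, bytes_per_slice * peer))     # peer's slice, at peer's offset
    print(tp_rank, sends, recvs)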
...@@ -312,4 +312,12 @@ void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inp ...@@ -312,4 +312,12 @@ void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inp
void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream); void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream);
void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream);
void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream);
#endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_
...@@ -98,6 +98,9 @@ void checkCuDriverContext(CUstream stream) { ...@@ -98,6 +98,9 @@ void checkCuDriverContext(CUstream stream) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
return; return;
#else #else
// Ensure the thread's "current" CUDA context is set.
cuda_driver::ensure_context_exists();
CUcontext ctx; CUcontext ctx;
const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx); const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
switch (driver_status) { switch (driver_status) {
...@@ -167,10 +170,10 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, ...@@ -167,10 +170,10 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) + void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
(offset_elems * type_num_bits) / 8); (offset_elems * type_num_bits) / 8);
NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment), NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT),
"Tensor data pointer must be 16B aligned"); "Tensor data pointer must be 16B aligned");
const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits; const int TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits;
NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits, NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
"-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX); "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
......
...@@ -94,7 +94,7 @@ struct SimpleTensor { ...@@ -94,7 +94,7 @@ struct SimpleTensor {
nvte_make_shape(this->shape.data(), this->shape.size())}; nvte_make_shape(this->shape.data(), this->shape.size())};
} }
int numel() const { size_t numel() const {
size_t acc = 1; size_t acc = 1;
for (const auto &dim : shape) { for (const auto &dim : shape) {
acc *= dim; acc *= dim;
...@@ -737,7 +737,8 @@ constexpr size_t scale_tensor_alignment_X_colwise = 128; ...@@ -737,7 +737,8 @@ constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4; constexpr size_t scale_tensor_alignment_Y_colwise = 4;
// Alignment requirements for the Tensor Memory Accelerator (TMA) // Alignment requirements for the Tensor Memory Accelerator (TMA)
constexpr int TMA_gmem_alignment = 16; // global memory address alignment constexpr size_t TMA_GMEM_ALIGNMENT = 16; // global memory address alignment
constexpr size_t TMA_SHMEM_ALIGNMENT = 128; // shared memory address alignment
inline bool is_aligned_ptr(const void *ptr, size_t alignment) { inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
return reinterpret_cast<uintptr_t>(ptr) % alignment == 0; return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
......
...@@ -183,7 +183,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -183,7 +183,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
!requires_64bit_ragged_offset) { !requires_64bit_ragged_offset &&
// 9.10.0: known bugs with SDPA FP8
(cudnn_runtime_version != 91000)) {
if (cudnn_runtime_version >= 8900) { if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8; backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else { } else {
...@@ -239,20 +241,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -239,20 +241,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
(!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
// 9.10: any head_dim + any arch + fprop + paged // 9.10.2: any head_dim + any arch + fprop + paged
// 9.10: any head_dim + any arch + fprop + non_paged + sq > 1 // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1
// 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} // 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
(!is_training && cudnn_runtime_version >= 91000 && (!is_training && cudnn_runtime_version >= 91002 &&
(layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 || (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 ||
(max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK && (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
// 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
(head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
cudnn_runtime_version >= 91100)) && cudnn_runtime_version >= 91100)) &&
// 9.11 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
(!(cudnn_runtime_version == 91100 && is_training && sm_arch_ == 90 && head_dim_qk >= 128 && (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training &&
head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
head_dim_qk != head_dim_v))) && !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) &&
// bias type // bias type
((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
(cudnn_runtime_version >= 8906 && (cudnn_runtime_version >= 8906 &&
...@@ -358,7 +360,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -358,7 +360,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
dropout == 0.0)))) && dropout == 0.0)))) &&
// check 64-bit ragged offset support // check 64-bit ragged offset support
(supported_ragged_offset_size)) { (supported_ragged_offset_size) &&
// 9.10.0/9.10.1: known bugs with SDPA F16
(cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) {
flag_arb = true; flag_arb = true;
} }
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
......
...@@ -90,7 +90,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, ...@@ -90,7 +90,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: Reduce to get the sum of aggregated_probs_per_expert * Section: Reduce to get the sum of aggregated_probs_per_expert
*/ */
CompType intermediate_result = CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id); warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id);
__syncwarp(); __syncwarp();
if (lane_id == 0) { if (lane_id == 0) {
...@@ -146,7 +146,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, ...@@ -146,7 +146,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: Reduce to get the sum of aggregated_probs_per_expert * Section: Reduce to get the sum of aggregated_probs_per_expert
*/ */
CompType intermediate_result = CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id); warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id);
__syncwarp(); __syncwarp();
if (lane_id == 0) { if (lane_id == 0) {
......
...@@ -107,7 +107,8 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi ...@@ -107,7 +107,8 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi
if (score_function == 0) { if (score_function == 0) {
if (topk > 1) { if (topk > 1) {
auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, sum, lane_id); auto sum_logits =
warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id);
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_logits[i] = static_cast<DataType>(static_cast<double>(local_logits[i]) / local_logits[i] = static_cast<DataType>(static_cast<double>(local_logits[i]) /
(static_cast<double>(sum_logits) + epsilon)); (static_cast<double>(sum_logits) + epsilon));
...@@ -231,13 +232,15 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int ...@@ -231,13 +232,15 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int
*/ */
// Sigmoid Post-processing bwd when topk > 1 // Sigmoid Post-processing bwd when topk > 1
if (topk > 1 && score_function == 0) { if (topk > 1 && score_function == 0) {
auto sum_fwd_input = warp_reduce_on_shmem(local_act_from_fwd, num_experts, sum, lane_id); auto sum_fwd_input =
warp_reduce_on_shmem(local_act_from_fwd, num_experts, ReduceFuncType::SUM, lane_id);
// Put the result of output * grad to the comp_buf // Put the result of output * grad to the comp_buf
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i]; local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i];
} }
__syncwarp(); __syncwarp();
auto sum_Output_x_Grad = warp_reduce_on_shmem(local_comp_buf, num_experts, sum, lane_id); auto sum_Output_x_Grad =
warp_reduce_on_shmem(local_comp_buf, num_experts, ReduceFuncType::SUM, lane_id);
// In-place update // In-place update
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_grad[i] = local_grad[i] =
......
...@@ -220,7 +220,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( ...@@ -220,7 +220,7 @@ __global__ void fused_topk_with_score_function_forward_kernel(
// score_function == 0 means sigmoid // score_function == 0 means sigmoid
if (score_function == 0) { if (score_function == 0) {
if (topk > 1) { if (topk > 1) {
double sum_scores = warp_reduce_on_shmem(topk_scores, topk, sum, lane_id); double sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id);
for (int i = lane_id; i < topk; i += kThreadsPerWarp) { for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
topk_scores[i] = static_cast<double>(topk_scores[i]) / (sum_scores + epsilon); topk_scores[i] = static_cast<double>(topk_scores[i]) / (sum_scores + epsilon);
} }
...@@ -362,7 +362,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( ...@@ -362,7 +362,7 @@ __global__ void fused_topk_with_score_function_backward_kernel(
/*data ptr = */ local_act_from_fwd, /*data ptr = */ local_act_from_fwd,
/*mask ptr = */ local_routing_map, /*mask ptr = */ local_routing_map,
/*data size = */ num_experts, /*data size = */ num_experts,
/*reduce func = */ sum, lane_id); /*reduce func = */ ReduceFuncType::SUM, lane_id);
// Put the result of output * grad to the comp_buf // Put the result of output * grad to the comp_buf
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_comp_buf[i] = (local_routing_map[i] ? static_cast<double>(local_grad[i]) * local_comp_buf[i] = (local_routing_map[i] ? static_cast<double>(local_grad[i]) *
...@@ -374,7 +374,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( ...@@ -374,7 +374,7 @@ __global__ void fused_topk_with_score_function_backward_kernel(
/*data ptr = */ local_comp_buf, /*data ptr = */ local_comp_buf,
/*mask ptr = */ local_routing_map, /*mask ptr = */ local_routing_map,
/*data size = */ num_experts, /*data size = */ num_experts,
/*reduce func = */ sum, lane_id); /*reduce func = */ ReduceFuncType::SUM, lane_id);
// In-place update // In-place update
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
if (local_routing_map[i]) { if (local_routing_map[i]) {
......
...@@ -30,14 +30,28 @@ __device__ inline T sum(T a, T b) { ...@@ -30,14 +30,28 @@ __device__ inline T sum(T a, T b) {
return a + b; return a + b;
} }
enum ReduceFuncType {
SUM,
MAX,
};
template <typename T> template <typename T>
__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_func)(T, T), __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncType type,
int lane_id) { int lane_id) {
T (*reduce_func)(T, T);
double default_val = 0;
if (type == ReduceFuncType::SUM) {
reduce_func = sum;
default_val = 0;
} else if (type == ReduceFuncType::MAX) {
reduce_func = max;
default_val = -std::numeric_limits<double>::infinity();
}
  // Some value is handled in local thread // Some value is handled in local thread
// Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ...
// Reduce the value in local thread // Reduce the value in local thread
volatile double val = volatile double val = lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : default_val;
lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : static_cast<double>(0);
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
val = reduce_func(val, data_ptr[i]); val = reduce_func(val, data_ptr[i]);
} }
...@@ -69,13 +83,22 @@ __device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, i ...@@ -69,13 +83,22 @@ __device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, i
template <typename T> template <typename T>
__device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size, __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size,
T (*reduce_func)(T, T), int lane_id) { ReduceFuncType type, int lane_id) {
T (*reduce_func)(T, T);
double default_val = 0;
if (type == ReduceFuncType::SUM) {
reduce_func = sum;
default_val = 0;
} else if (type == ReduceFuncType::MAX) {
reduce_func = max;
default_val = -std::numeric_limits<double>::infinity();
}
  // Some value is handled in local thread // Some value is handled in local thread
// Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ...
// Reduce the value in local thread // Reduce the value in local thread
volatile double val = lane_id < data_size && mask[lane_id] volatile double val =
? static_cast<double>(data_ptr[lane_id]) lane_id < data_size && mask[lane_id] ? static_cast<double>(data_ptr[lane_id]) : default_val;
: static_cast<double>(0);
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
if (mask[i]) { if (mask[i]) {
val = reduce_func(val, data_ptr[i]); val = reduce_func(val, data_ptr[i]);
...@@ -128,7 +151,7 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_ ...@@ -128,7 +151,7 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_
float sum_Output_x_Grad = warp_reduce_on_shmem( float sum_Output_x_Grad = warp_reduce_on_shmem(
/*data ptr = */ comp_buf, /*data ptr = */ comp_buf,
/*data size = */ data_size, /*data size = */ data_size,
/*reduce func = */ sum, lane_id); /*reduce func = */ ReduceFuncType::SUM, lane_id);
// In-place update // In-place update
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
if (mask) { if (mask) {
...@@ -147,14 +170,16 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_ ...@@ -147,14 +170,16 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_
template <typename DataType> template <typename DataType>
__device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) { __device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) {
// 1. compute the max of value // 1. compute the max of value
float max_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, max, lane_id)); float max_val =
static_cast<float>(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::MAX, lane_id));
// 2. value -> exp_value // 2. value -> exp_value
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
scores[i] = static_cast<float>(exp(static_cast<float>(scores[i]) - max_val)); scores[i] = static_cast<float>(exp(static_cast<float>(scores[i]) - max_val));
} }
__syncwarp(); __syncwarp();
// 3. compute the sum of exp_value // 3. compute the sum of exp_value
float sum_val = static_cast<float>(warp_reduce_on_shmem(scores, data_size, sum, lane_id)); float sum_val =
static_cast<float>(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::SUM, lane_id));
// 4. update the softmax value // 4. update the softmax value
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
scores[i] = static_cast<float>(scores[i]) / sum_val; scores[i] = static_cast<float>(scores[i]) / sum_val;
...@@ -165,19 +190,29 @@ __device__ inline void apply_softmax_on_float(DataType *scores, int data_size, i ...@@ -165,19 +190,29 @@ __device__ inline void apply_softmax_on_float(DataType *scores, int data_size, i
template <typename T> template <typename T>
__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices, __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices,
T *topk_scores, int lane_id) { T *topk_scores, int lane_id) {
// Check if the index is masked by the later iteration
auto is_masked = [&topk_indices](int k, int index) {
if (k == 0) return false;
for (int i = 0; i < k; i++) {
if (topk_indices[i] == index) return true;
}
return false;
};
// Topk Times: Find the max value and its index // Topk Times: Find the max value and its index
// Then mask it, and record the index in the topk_indices // Then mask it, and record the index in the topk_indices
// After looping topk times, the topk_indices will be the topk indices // After looping topk times, the topk_indices will be the topk indices
for (int k = 0; k < topk; k++) { for (int k = 0; k < topk; k++) {
// Find the max value and its index // Find the max value and its index
volatile double val = volatile double val = (lane_id < data_size && !is_masked(k, lane_id))
(lane_id < data_size) ? static_cast<double>(scores[lane_id]) : static_cast<double>(0); ? static_cast<double>(scores[lane_id])
: -std::numeric_limits<double>::infinity();
volatile int index = (lane_id < data_size) ? lane_id : 0; volatile int index = (lane_id < data_size) ? lane_id : 0;
  // Some value is handled in local thread // Some value is handled in local thread
// Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ...
// Reduce the value in local thread // Reduce the value in local thread
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
volatile double cur_val = scores[i]; volatile double cur_val = (is_masked(k, i)) ? -std::numeric_limits<double>::infinity()
: static_cast<double>(scores[i]);
if (cur_val > val) { if (cur_val > val) {
val = cur_val; val = cur_val;
index = i; index = i;
...@@ -200,17 +235,9 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i ...@@ -200,17 +235,9 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
if (lane_id == 0) { if (lane_id == 0) {
topk_indices[k] = index; topk_indices[k] = index;
topk_scores[k] = val; topk_scores[k] = val;
scores[index] =
static_cast<double>(-1.0) - val; // make the selected experts using val = - 1 - val
} }
__syncwarp(); __syncwarp();
} }
// Reset the scores to the original value
for (int i = lane_id; i < topk; i += kThreadsPerWarp) {
scores[topk_indices[i]] =
static_cast<double>(-1.0) - static_cast<double>(scores[topk_indices[i]]);
}
} }
// Current TE only support float32/bf16/fp16, float64 probs should be considered in the future // Current TE only support float32/bf16/fp16, float64 probs should be considered in the future
......
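The rewritten selection above no longer overwrites already-picked scores with "-1 - val" and restores them afterwards; it simply skips indices that are already in topk_indices via is_masked, leaving the scores buffer untouched. A single-threaded Python sketch of the same logic (the kernel does this cooperatively across a warp):

import math

def naive_topk_and_mask(scores, topk):
    topk_indices, topk_scores = [], []
    for _ in range(topk):
        best_val, best_idx = -math.inf, 0
        for i, s in enumerate(scores):
            if i in topk_indices:       # is_masked: skip experts picked in earlier rounds
                continue
            if s > best_val:
                best_val, best_idx = s, i
        topk_indices.append(best_idx)
        topk_scores.append(best_val)
    return topk_indices, topk_scores     # scores is left unmodified

print(naive_topk_and_mask([0.1, 0.7, 0.3, 0.9], topk=2))   # ([3, 1], [0.9, 0.7])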
...@@ -253,8 +253,9 @@ using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublas ...@@ -253,8 +253,9 @@ using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublas
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa,
cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, float alpha, float beta, bool use_split_accumulator, int math_sm_count,
int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter,
cudaStream_t stream) {
// Tensor dims in row-major order // Tensor dims in row-major order
const int A0 = inputA->flat_first_dim(); const int A0 = inputA->flat_first_dim();
const int A1 = inputA->flat_last_dim(); const int A1 = inputA->flat_last_dim();
...@@ -310,13 +311,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -310,13 +311,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
"fp8 Aux output for gemm + gelu fusion not supported!"); "fp8 Aux output for gemm + gelu fusion not supported!");
} }
if (is_fp8_dtype(outputD->data.dtype)) { if (is_fp8_dtype(outputD->data.dtype)) {
NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!"); NVTE_CHECK(beta == 0.0f, "Accumulation mode not supported with FP8 GEMM output!");
} }
float one = 1.0;
float zero = 0.0;
float beta = (accumulate) ? one : zero;
cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();
cublasLtMatmulDesc_t operationDesc = nullptr; cublasLtMatmulDesc_t operationDesc = nullptr;
...@@ -601,7 +598,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -601,7 +598,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
// D = alpha * (A * B) + beta * C // D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
static_cast<const void *>(&one), /* alpha */ static_cast<const void *>(&alpha), /* alpha */
param.A, /* A */ param.A, /* A */
Adesc, param.B, /* B */ Adesc, param.B, /* B */
Bdesc, static_cast<const void *>(&beta), /* beta */ Bdesc, static_cast<const void *>(&beta), /* beta */
...@@ -752,8 +749,27 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons ...@@ -752,8 +749,27 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
#else #else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); 1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false,
#endif //__HIP_PLATFORM_AMD__ nullptr, stream);
#endif //__HIP_PLATFORM_AMD__
}
void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
bool use_split_accumulator, int math_sm_count, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_gemm_scaled);
using namespace transformer_engine;
const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
Tensor *outputD = convertNVTETensor(D);
const Tensor *biasTensor = convertNVTETensor(bias);
Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
Tensor *wspace = convertNVTETensor(workspace);
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
alpha, beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream);
} }
void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
...@@ -846,8 +862,8 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor ...@@ -846,8 +862,8 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#else #else
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, 1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split,
inputCounter, stream); n_split, gemm_producer, inputCounter, stream);
#endif //__HIP_PLATFORM_AMD__ #endif //__HIP_PLATFORM_AMD__
} }
......
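For reference, the scaling semantics behind this change: the boolean accumulate flag collapses to beta in {0, 1} with alpha fixed at 1, while nvte_cublas_gemm_scaled exposes both scalars of D = alpha * (A x B) + beta * C. A small torch sketch of that arithmetic (reference semantics only, not the cuBLAS path):

import torch

def gemm_scaled(A, B, C, alpha=1.0, beta=0.0):
    # D = alpha * (A @ B) + beta * C
    return alpha * (A @ B) + beta * C

A, B, C = torch.randn(16, 32), torch.randn(32, 8), torch.randn(16, 8)
out_no_accumulate = gemm_scaled(A, B, C, alpha=1.0, beta=0.0)   # legacy accumulate=False
out_accumulate    = gemm_scaled(A, B, C, alpha=1.0, beta=1.0)   # legacy accumulate=True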