"src/targets/vscode:/vscode.git/clone" did not exist on "03929873b3098c5d4756741d117af0b2d7173744"
Unverified Commit 73f8d90f authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] cuda graph support (#575)



* FP8 cuda graphs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>

* Fix numerics
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* exclude torch compile from numerics tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More numerics fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm fusion from unfused path
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
parent 1b20f2d6
......@@ -41,4 +41,6 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.onnx_export
.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables
.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
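The user-facing entry point documented above is `transformer_engine.pytorch.make_graphed_callables`, added by this commit. A minimal usage sketch, mirroring how the new `tests/pytorch/test_cuda_graphs.py` exercises it (module sizes, the SGD settings, and the `use_fp8` toggle are illustrative; the FP8 path needs FP8-capable hardware such as Hopper):

```python
import torch
import transformer_engine.pytorch as te

use_fp8 = False  # set True on FP8-capable hardware to capture the FP8 path

layer = te.Linear(64, 64, device="cuda", params_dtype=torch.float16)

# Sample inputs used only to warm up and capture the CUDA graph.
sample_args = (torch.ones(32, 2, 64, device="cuda", dtype=torch.float16, requires_grad=True),)
graphed_layer = te.make_graphed_callables(
    layer,
    sample_args,
    num_warmup_iters=10,
    fp8_enabled=use_fp8,
)

optimizer = torch.optim.SGD(layer.parameters(), lr=0.001)
for _ in range(10):
    x = torch.randn(32, 2, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    with te.fp8_autocast(enabled=use_fp8):
        out = graphed_layer(x)  # replays the captured forward graph
    out.sum().backward()        # replays the captured backward graph
    optimizer.step()
    optimizer.zero_grad()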
......@@ -9,9 +9,10 @@ set -e
pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
......@@ -5,7 +5,6 @@
import functools
from importlib.metadata import version
import os
import math
from typing import Any, Dict, List, Tuple, Union
from pkg_resources import packaging
......@@ -28,15 +27,9 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_bwd,
fused_attn_fwd,
)
from transformer_engine.pytorch.distributed import (
_set_cuda_rng_state,
CudaRNGStatesTracker,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import (
TransformerEngineBaseModule,
_prepare_backward,
)
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
......@@ -58,10 +51,18 @@ _cuda_rng_state = torch.cuda.get_rng_state()
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
fp8.FP8GlobalStateManager.reset()
@functools.cache
def _cudnn_version() -> Tuple[int, int, int]:
......@@ -71,6 +72,7 @@ def _cudnn_version() -> Tuple[int, int, int]:
minor, patch = divmod(encoded_version, 100)
return (major, minor, patch)
class ModelConfig:
def __init__(
self,
......@@ -103,6 +105,7 @@ class ModelConfig:
self.num_layers = num_layers
self.bias_shape = bias_shape
def _is_fused_attention_supported(
config: ModelConfig,
dtype: torch.dtype,
......@@ -151,24 +154,28 @@ def _is_fused_attention_supported(
return True, backends
return False, backends
@functools.cache
def _is_flash_attention_2_available() -> bool:
"""Check if flash-attn 2.0+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2")
@functools.cache
def _is_flash_attention_2_1() -> bool:
"""Check if flash-attn 2.1+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.1")
@functools.cache
def _is_flash_attention_2_3() -> bool:
"""Check if flash-attn 2.3+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.3")
def _is_flash_attention_supported(config: ModelConfig) -> bool:
"""Check if FlashAttention supports a model configuration"""
if get_device_compute_capability() < (8, 0):
......@@ -184,6 +191,7 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool:
return False
return True
def _is_unfused_attention_supported(config: ModelConfig) -> bool:
"""Check if UnfusedDotProductAttention supports a model configuration"""
if ("padding" in config.attn_mask_type):
......@@ -192,6 +200,7 @@ def _is_unfused_attention_supported(config: ModelConfig) -> bool:
return False
return True
model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
......@@ -200,11 +209,13 @@ model_configs_base = {
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
}
param_types = [torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]
def get_swa(seq_q, seq_kv, w=None):
"""Generate a random sliding window size (left, right) if w is None,
and create its equivalent attention mask in [seq_q, seq_kv] shape"""
......@@ -216,6 +227,7 @@ def get_swa(seq_q, seq_kv, w=None):
ml = ~ ml
return w, ml
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
......@@ -313,6 +325,7 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
for i,_ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
......@@ -321,6 +334,7 @@ def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False)
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
......@@ -337,6 +351,7 @@ model_configs_mask = {
"mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
}
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
......@@ -345,6 +360,7 @@ def test_dpa_mask(dtype, model_configs, model):
"""Test DotProductAttention module with different mask types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
......@@ -373,6 +389,7 @@ model_configs_bias = {
"bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
}
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
......@@ -381,6 +398,7 @@ def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0,
......@@ -398,6 +416,7 @@ model_configs_bias_shapes = {
"causal", "alibi", bias_shape='bhss', alibi_type='custom'),
}
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
......@@ -413,6 +432,8 @@ model_configs_swa = {
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
}
@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
......@@ -428,6 +449,8 @@ model_configs_alibi_slopes = {
"alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type= "custom"),
"alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type= "custom"),
}
@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
......@@ -436,6 +459,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
"""Test DotProductAttention module with ALiBi slopes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
......@@ -443,6 +467,7 @@ qkv_layouts = [
#'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd',
]
model_configs_layout = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
......@@ -455,6 +480,7 @@ model_configs_layout = {
"layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
}
@pytest.mark.skipif(_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
......@@ -464,6 +490,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False)
def _run_dot_product_attention(
dtype: torch.dtype,
config: ModelConfig,
......@@ -646,6 +673,7 @@ def _run_dot_product_attention(
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
......@@ -658,6 +686,7 @@ model_configs_te_layer = {
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
}
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
......@@ -742,6 +771,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
......@@ -755,6 +785,7 @@ def test_te_layer_misc(dtype, model_configs, model, qkv_format):
test_transformer_layer(dtype, model_configs, model,
ckpt_attn, qkv_format, fused_qkv_params, RoPE)
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
......@@ -780,6 +811,7 @@ def test_te_layer_mqa_gqa(dtype, model_configs, model):
test_transformer_layer(dtype, model_configs, model,
ckpt_attn, qkv_format, fused_qkv_params, RoPE)
def _run_transformer_layer(
dtype: torch.dtype,
config: ModelConfig,
......@@ -912,8 +944,10 @@ model_configs_fp8 = {
"fp8_1": ModelConfig(1, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
}
param_types_fp8 = [torch.float16]
@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
......@@ -946,6 +980,7 @@ def test_dpa_fp8(dtype, model):
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
def _run_dpa_fp8(dtype, config, backend):
"""Run FusedAttention FP8 backend, i.e.
fused_attn_fwd/bwd_qkvpacked from cpp_extensions"""
......@@ -989,6 +1024,7 @@ def _run_dpa_fp8(dtype, config, backend):
dqkv.view(config.batch_size, config.max_seqlen_q, 3,
config.num_heads, config.head_dim).transpose(0,1).contiguous())
def _run_dpa_fp8_ref(dtype, config, backend):
"""Run UnfusedDotProductAttention as a reference, i.e.
plain PyTorch implementation in FP16 and inputs/outputs
......@@ -1188,8 +1224,7 @@ class _dpa_fp8(torch.autograd.Function):
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"):
with torch.cuda.nvtx.range("_DPA"):
(
inputmat_t,
qkv_weight_t_fp8,
......@@ -1298,6 +1333,7 @@ class _dpa_fp8(torch.autograd.Function):
None,
None)
class DPA_FP8(TransformerEngineBaseModule):
def __init__(
self,
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import List, Tuple
import pytest
import torch
from transformer_engine.pytorch import (
DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, make_graphed_callables,
MultiheadAttention, TransformerLayer, fp8_autocast, fp8_model_init,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
class ModelConfig:
def __init__(self, hidden_size, nheads, kv, seq_len):
self.h = hidden_size
self.nheads = nheads
self.kv = kv
self.s = seq_len
model_configs = {
"small": ModelConfig(64, 2, 32, 32),
}
modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]
optimizers = [torch.optim.SGD, torch.optim.Adam]
all_boolean = [True, False]
dtypes = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
dtypes.append(torch.bfloat16)
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
FP8GlobalStateManager.reset()
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
failed = False
failed_tensors = ""
for i, (t1, t2) in enumerate(zip(l1, l2)):
with torch.no_grad():
t1.masked_fill_(t1.isnan(), 1.0)
t2.masked_fill_(t2.isnan(), 1.0)
if not torch.equal(t1, t2):
failed = True
failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
assert not failed, "Output mismatches in:\n" + failed_tensors
def generate_data(
s: int, b: int, h: int, nheads: int, kv: int, dtype: torch.dtype,
dpa: bool = False, warmup: bool = False, gen_labels: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generate synthetic data."""
gen_func = torch.ones if warmup else torch.randn
if dpa:
inputs = [gen_func(s, b, nheads, kv, device="cuda", requires_grad=True, dtype=dtype) for _ in range(3)]
else:
inputs = [gen_func(s, b, h, device="cuda", requires_grad=True, dtype=dtype)]
if not gen_labels:
return inputs
target = torch.randn(s, b, h, device="cuda", dtype=dtype)
return inputs, target
def get_outputs(model, output):
"""Return grads and params for comparsion."""
values = []
for param in model.parameters():
values.append(param)
if param.grad is not None:
values.append(param.grad)
values.append(output)
return values
def _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, graph, module, optimizer, graph_mode=""):
"""Helper function for test."""
reset_rng_states()
FP8GlobalStateManager.reset()
dpa = module == "dpa"
with fp8_model_init(enabled=fp8_params):
# Create modules.
if module == "transformer":
modules = [TransformerLayer(
config.h,
config.h,
config.nheads,
hidden_dropout=0.0,
attention_dropout=0.0,
fuse_qkv_params=True,
params_dtype=dtype,
) for _ in range(num_layers)]
elif module == "layernorm_mlp":
modules = [LayerNormMLP(
config.h, config.h, params_dtype=dtype
) for _ in range(num_layers)]
elif module == "layernorm_linear":
modules = [LayerNormLinear(
config.h, config.h, params_dtype=dtype
) for _ in range(num_layers)]
elif module == "mha":
modules = [MultiheadAttention(
config.h,
config.nheads,
attention_dropout=0.0,
params_dtype=dtype,
fuse_qkv_params=True,
) for _ in range(num_layers)]
elif dpa:
assert config.h % config.nheads == 0, "Hidden size must be divisible by the number of heads."
assert num_layers == 1, "DPA test supports a single layer only."
modules = [DotProductAttention(
config.nheads, config.kv, attention_dropout=0.0
) for _ in range(num_layers)]
else:
modules = [Linear(
config.h, config.h, device="cuda", params_dtype=dtype
) for _ in range(num_layers)]
# Generate model and wrap API to return graphed version.
if graph:
# Graph entire module at once.
if graph_mode == "full":
model = modules[0] if dpa else torch.nn.Sequential(*modules)
model = make_graphed_callables(
model,
generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True),
num_warmup_iters=10,
fp8_enabled=fp8)
else:
modules = [make_graphed_callables(
module,
generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True),
num_warmup_iters=10,
fp8_enabled=fp8) for module in modules]
model = modules[0] if dpa else torch.nn.Sequential(*modules)
else:
model = modules[0] if dpa else torch.nn.Sequential(*modules)
# Loss function and optimizer.
loss_fn = torch.nn.MSELoss()
if not dpa:
optimizer = optimizer(model.parameters(), lr=0.001)
# Launch.
for _ in range(10):
inputs, target = generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, gen_labels=True)
with fp8_autocast(enabled=fp8):
output = model(*inputs)
loss = loss_fn(output, target)
loss.backward()
if not dpa:
optimizer.step()
optimizer.zero_grad()
return get_outputs(model, output)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("num_layers", [1, 10])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_params", all_boolean)
@pytest.mark.parametrize("module", modules)
@pytest.mark.parametrize("optimizer", optimizers)
def test_gpt_make_graphed_callables(dtype, bs, model, num_layers, fp8, fp8_params, module, optimizer):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_params and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if module == "dpa" and num_layers > 1:
pytest.skip("Max 1 layer for DPA.")
config = model_configs[model]
outputs = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, False, module, optimizer)
graph_outputs_mode1 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="full")
graph_outputs_mode2 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="individual")
# Check that results match
assert_all_equal(outputs, graph_outputs_mode1)
assert_all_equal(outputs, graph_outputs_mode2)
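The test above checks eager execution against two graphing strategies. A condensed sketch of just those two modes with the same API (layer count and sizes illustrative):

```python
import torch
import transformer_engine.pytorch as te

def build_layers():
    return [te.LayerNormMLP(64, 64, params_dtype=torch.float16) for _ in range(3)]

sample_args = (torch.ones(32, 2, 64, device="cuda", dtype=torch.float16, requires_grad=True),)

# "full": capture the whole stack as one graph.
full = te.make_graphed_callables(
    torch.nn.Sequential(*build_layers()), sample_args, num_warmup_iters=10)

# "individual": capture one graph per layer, then compose the graphed layers.
individual = torch.nn.Sequential(
    *[te.make_graphed_callables(m, sample_args, num_warmup_iters=10) for m in build_layers()])
```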
......@@ -257,12 +257,10 @@ class TestFloat8Tensor:
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, x_ref, **tols)
@pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]])
@pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)])
@pytest.mark.parametrize("dims", [[33, 41], [7, 11]])
def test_transpose(
self,
dims: DimsType,
transpose_dims: Tuple[int, int],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 0.5,
dtype: torch.dtype = torch.float32,
......@@ -271,74 +269,44 @@ class TestFloat8Tensor:
# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
x_ref,
x,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_ref = x_fp8.from_float8()
x = x_fp8.from_float8()
# Perform transpose
y_fp8 = x_fp8.transpose(*transpose_dims)
y_ref = x_ref.transpose(*transpose_dims)
x_fp8_t = x_fp8.transpose_2d()
x_t = x.transpose(0, 1)
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8_t)
# Check results
tols = dict(rtol=0, atol=0)
torch.testing.assert_close(y_fp8, y_ref, **tols)
torch.testing.assert_close(x_fp8_t, x_t, **tols)
# Make sure we are not trivially passing the test
if transpose_dims[0] != transpose_dims[1]:
with pytest.raises(AssertionError):
torch.testing.assert_close(
y_fp8,
x_ref,
**tols,
)
# Check transpose caching
if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]:
torch.testing.assert_close(x_fp8_t, x, **tols)
# Check that cached transpose is returned when expected
# Note: Sneakily destroy data so that recalculating
# transpose would give wrong answer.
# Caching test.
assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching."
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache="lazy"),
x_ref.transpose(*transpose_dims),
**tols,
)
x_fp8_data = x_fp8._data.clone()
x_fp8._data.zero_()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims),
x_ref.transpose(*transpose_dims),
**tols,
)
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache="lazy"),
x_ref.transpose(*transpose_dims),
**tols,
)
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache="force"),
torch.zeros_like(x_ref.transpose(*transpose_dims)),
rtol=0,
atol=0,
)
x_fp8._data.copy_(x_fp8_data)
x_fp8._reset_caches()
x = x_fp8.from_float8()
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(cache=True))
x_t = x.transpose(0, 1)
torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
# Make sure cache is reset after in-place operation
x_fp8.transpose(*transpose_dims, update_cache="force")
# Inplace update test.
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims),
x_ref.transpose(*transpose_dims),
**tols,
)
assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly."
x = x_fp8.from_float8()
x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(cache=True))
x_t = x.transpose(0, 1)
torch.testing.assert_close(x_fp8_t, x_t, **tols)
assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
def test_serialization(
self,
......
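The rewritten test above pins down the `transpose_2d` caching contract: a plain call leaves the cache invalid, `cache=True` fills it, and an in-place update invalidates it again. A condensed sketch of that contract (the `Float8Tensor` import path and the E4M3 default are assumptions; method names are as used in the test):

```python
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor  # assumed import path

x_fp8 = Float8Tensor.to_float8(
    2 * torch.rand(33, 41, device="cuda") - 1,  # same sizing as the test
    scale=torch.full([1], 0.5),
)

# Plain 2D transpose: recomputed on every call, cache stays invalid.
x_t = x_fp8.transpose_2d()
assert x_fp8._transpose_invalid

# Cached transpose: stored until an in-place update invalidates it.
x_t = x_fp8.transpose_2d(cache=True)
assert not x_fp8._transpose_invalid
x_fp8 += 0.5  # in-place op must reset the cache
assert x_fp8._transpose_invalid
```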
......@@ -4,7 +4,6 @@
import math
import os
import sys
from typing import List, Optional
import pytest
import copy
......@@ -25,7 +24,6 @@ from transformer_engine.pytorch import (
MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, InferenceParams
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
# Only run FP8 tests on H100.
......@@ -54,6 +52,14 @@ model_configs = {
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
}
model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
}
backends_inference = ["FlashAttention", "UnfusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
param_types = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
......@@ -104,7 +110,13 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float)
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
FP8GlobalStateManager.reset()
class TorchScaledMaskedSoftmax(nn.Module):
......@@ -373,10 +385,10 @@ class TorchGPT(nn.Module):
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
a = self.ln(x)
b = self.causal_attn(a, attn_mask)
b = self.causal_attn(a, attention_mask)
if self.parallel_attention_mlp:
n = self.ln_mlp(x)
x = x + nn.functional.dropout(b + n, p=0.1, training=self.training)
......@@ -396,13 +408,6 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
with fp8_model_init(enabled=fp8 and fp8_model_params):
block = (
TransformerLayer(
......@@ -417,7 +422,6 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
fuse_qkv_params=True,
)
......@@ -476,13 +480,6 @@ def _test_e2e_full_recompute(
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
with fp8_model_init(enabled=fp8 and fp8_model_params):
block = (
TransformerLayer(
......@@ -497,7 +494,6 @@ def _test_e2e_full_recompute(
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
fuse_qkv_params=True,
)
......@@ -520,7 +516,6 @@ def _test_e2e_full_recompute(
checkpoint_core_attention=False,
distribute_saved_activations=False,
tp_group=None,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
use_reentrant=use_reentrant,
)
else:
......@@ -683,7 +678,7 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len)
out = block(inp_hidden_states, inp_attn_mask)
out = block(inp_hidden_states, attention_mask=inp_attn_mask)
loss = out.sum()
loss.backward()
......@@ -1261,13 +1256,6 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
with fp8_model_init(enabled=fp8_model_params):
block = (
TransformerLayer(
......@@ -1282,7 +1270,6 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):
kv_channels=config.embed,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
fuse_qkv_params=True,
)
......@@ -1321,6 +1308,7 @@ def test_gpt_fp8_parameters(dtype, bs, model):
outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True)
assert_all_equal(outputs, outputs_fp8_params)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
......@@ -1399,14 +1387,6 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()])
model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
}
backends_inference = ["FlashAttention", "UnfusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model_key", model_configs_inference.keys())
......
......@@ -86,6 +86,12 @@ def set_max_seq_len(max_seq_len=128):
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}"
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
FP8GlobalStateManager.reset()
def create_fp8_recipe():
return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
......
......@@ -48,6 +48,7 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
"""Custom func to test recipe."""
return torch.min(amax_history, dim=0).values
@dataclass
class ModelConfig:
"""Transformer model configuration"""
......@@ -115,6 +116,12 @@ def _disable_wgrads(block):
p.requires_grad = False
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
FP8GlobalStateManager.reset()
def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
# Initialize loss function and optimizer.
loss_fn = torch.nn.MSELoss()
......@@ -137,7 +144,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
with torch.cuda.stream(s):
for _ in range(3):
optimizer.zero_grad(set_to_none=True)
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
out = block(static_input)
loss = loss_fn(out, static_target)
loss.backward()
......@@ -148,7 +155,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
static_output = block(static_input)
static_loss = loss_fn(static_output, static_target)
static_loss.backward()
......
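The sanity test above now passes the new private `_graph=True` flag to `fp8_autocast` while capturing manually with `torch.cuda.CUDAGraph`. A trimmed sketch of that warmup-then-capture pattern (the `Linear` block, shapes, and `use_fp8` toggle are illustrative):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_autocast

use_fp8 = False  # set True on FP8-capable hardware
block = te.Linear(128, 128, device="cuda", params_dtype=torch.float16)
static_input = torch.randn(16, 128, device="cuda", dtype=torch.float16)
static_target = torch.randn(16, 128, device="cuda", dtype=torch.float16)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(block.parameters(), lr=0.1)

# Warm up on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        with fp8_autocast(enabled=use_fp8, _graph=True):
            loss = loss_fn(block(static_input), static_target)
        loss.backward()
torch.cuda.current_stream().wait_stream(s)

# Capture one forward/backward iteration into a CUDA graph.
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    with fp8_autocast(enabled=use_fp8, _graph=True):
        static_loss = loss_fn(block(static_input), static_target)
    static_loss.backward()
```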
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file transpose_with_noop.h
* \brief Functions handling transposes with no-op.
*/
#ifndef TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_
#define TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
void nvte_transpose_with_noop(const NVTETensor input,
const NVTETensor noop,
NVTETensor output,
cudaStream_t stream);
void nvte_cast_transpose_with_noop(const NVTETensor input,
const NVTETensor noop,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_
......@@ -56,6 +56,45 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his
float margin,
cudaStream_t stream);
/*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction.
*
 * Operations performed include: updating the most recent amax history
 * with the relevant segment of the global reduction buffer if it is nonzero,
 * rotating the amax history according to the rule below, and updating the
 * scales and scale_invs.
*
* The amax history is rotated by -1 (e.g. the first entry shifts to
* the last, the last entry shifts to the second to last) and the
* first entry is set to zero. The scaling factor is estimated so the
* FP8 tensor's maximum absolute value is
* @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$.
*
* \param[in] amax_reduction_buffer The contiguous buffer used for amax reduction.
* Shape: [num_scales * num_tensors]
* \param[in,out] amax_histories List of amax histories of maximum absolute values.
* Shape: num_tensors x [history_length, num_scales]
* \param[in,out] scales List of scaling factors for casting to FP8.
* Shape: num_tensors x [num_scales]
* \param[in,out] scale_invs List of scaling factors for casting from FP8.
* Shape: num_tensors x [num_scales]
* \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and
* "most_recent".
* \param[in] fp8_dtype FP8 datatype.
* \param[in] margin Scaling factor margin.
* \param[in] stream CUDA stream.
*/
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
const NVTETensor amax_reduction_buffer,
std::vector<NVTETensor> amax_histories,
std::vector<NVTETensor> scales,
std::vector<NVTETensor> scale_invs,
const char *amax_compute_algo,
NVTEDType fp8_dtype,
float margin,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
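The rule documented for the new bulk update can be written out per tensor as a small PyTorch reference; this is an illustration of the documented math, not the bulk CUDA kernel added below:

```python
import torch

def delayed_scaling_update(amax_history, scale, reduced_amax, fp8_max, margin, algo="max"):
    """One tensor's update, following the doc comment above.

    amax_history: [history_length, num_scales]; scale: [num_scales];
    reduced_amax: [num_scales] slice of the global reduction buffer, or None.
    """
    # Fold the globally reduced amax into the most recent history entry if nonzero.
    if reduced_amax is not None:
        amax_history[0] = torch.where(reduced_amax != 0, reduced_amax, amax_history[0])
    # Amax used for scaling: max over the window, or the most recent entry.
    amax = amax_history.amax(dim=0) if algo == "max" else amax_history[0]
    # Rotate the history by -1 and zero the new first entry.
    amax_history = torch.roll(amax_history, shifts=-1, dims=0)
    amax_history[0].zero_()
    # Scale chosen so the FP8 tensor's max abs value is 2^-margin * fp8_max;
    # keep the old scale where amax is zero or non-finite.
    new_scale = torch.where(torch.isfinite(amax) & (amax > 0),
                            (fp8_max / amax) * (2.0 ** -margin), scale)
    return amax_history, new_scale, 1.0 / new_scale
```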
......@@ -229,19 +229,29 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
if (workspace->data.dptr == nullptr) {
NVTE_CHECK(barrier->data.dptr == nullptr);
workspace->data.dtype = layer_norm::DType::kByte;
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
workspace->data.shape = { launch_params.workspace_bytes };
barrier->data.dtype = layer_norm::DType::kInt32;
barrier->data.shape = { launch_params.barrier_size };
return;
} else {
NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes });
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size });
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
......@@ -368,6 +378,27 @@ void layernorm_bwd(const Tensor& dz,
barrier->data.shape = { launch_params.barrier_size };
return;
} else {
NVTE_CHECK(dbeta_part->data.dptr != nullptr);
auto pdw_shape = std::vector<size_t>{
static_cast<uint64_t>(launch_params.params.ctas_per_col), hidden_size};
NVTE_CHECK(dgamma_part->data.dtype == ctype);
NVTE_CHECK(dgamma_part->data.shape == pdw_shape);
NVTE_CHECK(dbeta_part->data.dtype == ctype);
NVTE_CHECK(dbeta_part->data.shape == pdw_shape);
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size });
}
if (launch_params.workspace_bytes > 0) {
NVTE_CHECK(workspace->data.dptr != nullptr);
NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes });
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
......
......@@ -133,3 +133,13 @@ class DelayedScaling:
(False, False, False),
(False, False, True),
), "Only wgrad GEMM override is currently supported."
def __repr__(self) -> str:
return (
f"margin={self.margin}, "
f"interval={self.interval}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"amax_history_len={self.amax_history_len}, "
f"wgrad_override={self.override_linear_precision.wgrad}, "
f"reduce_amax={self.reduce_amax}"
)
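With the `__repr__` added above, printing a recipe now shows its key settings; for example (constructed as in `test_onnx_export.py`; the defaults shown in the comment are illustrative):

```python
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
print(fp8_recipe)
# e.g. "margin=0, interval=1, format=E4M3, amax_history_len=1024,
#       wgrad_override=False, reduce_amax=True" (defaults may differ by version)
```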
......@@ -11,6 +11,7 @@
#include "../common.h"
#include "../util/logging.h"
#include "../util/cuda_runtime.h"
namespace transformer_engine {
namespace delayed_scaling_recipe {
......@@ -38,6 +39,36 @@ inline float fp8_dtype_max(DType dtype) {
return 0;
}
// struct for amax parameters
struct AmaxParam {
int num_scale = 0;
float* amax_history = nullptr;
float* scale = nullptr;
float* scale_inv = nullptr;
};
// dummy struct for kernel_bulk's other params
struct OtherParams {
float* a;
size_t b;
AmaxComputeAlgo c;
float d;
};
#if CUDART_VERSION >= 12010
constexpr size_t max_constant_memory_per_kernel = 32000;
constexpr size_t AMAX_PARAMS_LIMIT = (
max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
#else
constexpr size_t max_constant_memory_per_kernel = 4000;
constexpr size_t AMAX_PARAMS_LIMIT = (
max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
#endif
struct AmaxParams {
AmaxParam param[AMAX_PARAMS_LIMIT];
};
namespace amax_and_scale_update_impl {
// CUDA block size
......@@ -133,11 +164,96 @@ kernel(const float* amax_history_ptr,
}
}
} // namespace amax_and_scale_update_impl
/* CUDA kernel to bulk-update amax history and FP8 scaling factors
*
* Block dims: bsize x 1 x 1
*
* Grid dims: num_tensors x 1 x 1
*/
__global__ void __launch_bounds__(bsize)
kernel_bulk(
float* amax_reduction_buffer,
AmaxParams p,
size_t amax_history_length,
AmaxComputeAlgo amax_compute_algo,
float scaled_max) {
const size_t bid = blockIdx.x;
const size_t tid = threadIdx.x;
const int num_scale = p.param[bid].num_scale;
int offset_in_buffer = 0;
for (int j = 0; j < bid; j++) {
offset_in_buffer += p.param[j].num_scale;
}
for (int count = 0; count < num_scale; count++) {
// Update amax
float amax = 0;
{
// Roll amax history
const auto& length = amax_history_length;
const auto& stride = p.param[bid].num_scale;
auto* amax_history = p.param[bid].amax_history+count;
const auto last_amax = ((amax_reduction_buffer != nullptr)
&& (amax_reduction_buffer[offset_in_buffer+count] != 0.0f)) ?
amax_reduction_buffer[offset_in_buffer+count] : amax_history[0];
for (size_t off = 0; off < length; off += bsize) {
const size_t i = off + tid;
float a = 0;
if (i < length) {
a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax;
amax = fmaxf(amax, a);
}
__syncthreads(); // Inplace roll
if (i < length) {
amax_history[i*stride] = (i > 0) ? a : 0;
}
}
// Compute amax to use for scaling factor
switch (amax_compute_algo) {
case AmaxComputeAlgo::MOST_RECENT:
amax = last_amax;
break;
case AmaxComputeAlgo::MAX:
{
__shared__ float shared_amax[bsize];
shared_amax[tid] = amax;
__syncthreads();
#pragma unroll
for (size_t off = bsize / 2; off > 0; off /= 2) {
if (tid < off) {
shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]);
}
__syncthreads();
}
amax = shared_amax[tid];
}
break;
default:
amax = 0;
}
}
// Update scale and scale inverse
if (tid == 0) {
float scale;
if (isfinite(amax) && amax > 0) {
scale = scaled_max / amax;
} else {
scale = p.param[bid].scale[count];
}
p.param[bid].scale[count] = scale;
p.param[bid].scale_inv[count] = 1 / scale;
}
}
}
} // namespace amax_and_scale_update_impl
} // namespace
void amax_and_scale_update(const Tensor &amax_history,
const Tensor &scale,
const Tensor &scale_inv,
......@@ -238,9 +354,105 @@ void amax_and_scale_update(const Tensor &amax_history,
NVTE_CHECK_CUDA(cudaGetLastError());
}
void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer,
std::vector<Tensor*> amax_histories,
std::vector<Tensor*> scales,
std::vector<Tensor*> scale_invs,
const std::string &amax_compute_algo,
DType fp8_dtype,
float margin,
cudaStream_t stream) {
using namespace transformer_engine;
// amax value to use for updating scaling factor
AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID;
if (amax_compute_algo == "max") {
amax_compute_algo_ = AmaxComputeAlgo::MAX;
} else if (amax_compute_algo == "most_recent") {
amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT;
} else {
NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")");
}
// Expected maximum value after scale is applied
const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin);
// Number of elements in tensor
auto numel = [] (const Tensor *tensor) -> size_t {
size_t acc = 1;
for (const auto& dim : tensor->data.shape) {
acc *= dim;
}
return acc;
};
// Number of tensors in the bulk
const size_t num_tensors = amax_histories.size();
const int num_kernels = (num_tensors+AMAX_PARAMS_LIMIT-1)/AMAX_PARAMS_LIMIT;
size_t amax_history_length = 0;
if (num_tensors > 0) {
amax_history_length = amax_histories[0]->data.shape[0];
}
// amax parameters
float* amax_buffer = static_cast<float*>(amax_reduction_buffer.data.dptr);
AmaxParams p;
for (int iter = 0; iter < num_kernels; iter++) {
size_t kernel_num_scales = 0;
size_t kernel_num_tensors = (iter == (num_kernels -1))
? num_tensors % AMAX_PARAMS_LIMIT: AMAX_PARAMS_LIMIT;
for (size_t pi = 0; pi < kernel_num_tensors; pi++) {
size_t i = iter * AMAX_PARAMS_LIMIT + pi;
// Check tensors
int num_scale = amax_histories[i]->data.shape[1];
NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32,
"Found ", dtype_name(amax_histories[i]->data.dtype), ".");
NVTE_CHECK(amax_histories[i]->data.shape.size() == 2,
"Found ", amax_histories[i]->data.shape.size(), " dims");
NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale,
"Expected ", amax_history_length * num_scale, " elements, ",
"but found ", numel(amax_histories[i]), ".");
NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32,
"Found ", dtype_name(scales[i]->data.dtype), ".");
NVTE_CHECK(scales[i]->data.shape.size() == 1,
"Found ", scales[i]->data.shape.size(), " dims");
NVTE_CHECK(numel(scales[i]) == num_scale,
"Expected ", num_scale, " elements, ",
"Found ", numel(scales[i]), ".");
// amax parameters
kernel_num_scales += num_scale;
p.param[pi].num_scale = num_scale;
p.param[pi].amax_history = static_cast<float*>(amax_histories[i]->data.dptr);
p.param[pi].scale = static_cast<float*>(scales[i]->data.dptr);
p.param[pi].scale_inv = static_cast<float*>(scale_invs[i]->data.dptr);
}
// Launch CUDA kernel
size_t grid_size = kernel_num_tensors;
const size_t block_size = amax_and_scale_update_impl::bsize;
amax_and_scale_update_impl::kernel_bulk
<<<grid_size, block_size, 0, stream>>>(
amax_buffer,
p,
amax_history_length,
amax_compute_algo_,
scaled_max);
NVTE_CHECK_CUDA(cudaGetLastError());
// shift amax buffer pointer
if (amax_buffer != nullptr) {
amax_buffer += kernel_num_scales;
}
}
}
} // namespace delayed_scaling_recipe
} // namespace transformer_engine
void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history,
const NVTETensor scale,
const NVTETensor scale_inv,
......@@ -267,3 +479,33 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his
margin,
stream);
}
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
const NVTETensor amax_reduction_buffer,
std::vector<NVTETensor> amax_histories,
std::vector<NVTETensor> scales,
std::vector<NVTETensor> scale_invs,
const char *amax_compute_algo,
NVTEDType fp8_dtype,
float margin,
cudaStream_t stream) {
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction);
using namespace transformer_engine;
size_t num_tensors = amax_histories.size();
std::vector<Tensor*> t_amax_histories, t_scales, t_scale_invs;
for (size_t i = 0; i < num_tensors; i++) {
t_amax_histories.push_back(reinterpret_cast<Tensor*>(amax_histories[i]));
t_scales.push_back(reinterpret_cast<Tensor*>(scales[i]));
t_scale_invs.push_back(reinterpret_cast<Tensor*>(scale_invs[i]));
}
delayed_scaling_recipe::amax_and_scale_update_after_reduction(
*reinterpret_cast<const Tensor*>(amax_reduction_buffer),
t_amax_histories,
t_scales,
t_scale_invs,
amax_compute_algo,
static_cast<DType>(fp8_dtype),
margin,
stream);
}
......@@ -153,21 +153,32 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
if (workspace->data.dptr == nullptr) {
NVTE_CHECK(barrier->data.dptr == nullptr);
workspace->data.dtype = DType::kByte;
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
workspace->data.shape = {launch_params.workspace_bytes};
barrier->data.dtype = DType::kInt32;
barrier->data.shape = {launch_params.barrier_size};
return;
} else {
NVTE_CHECK(workspace->data.dtype == DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes });
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size });
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
......@@ -265,6 +276,23 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
barrier->data.shape = {launch_params.barrier_size};
return;
} else {
auto pdw_shape = std::vector<size_t>{
static_cast<uint64_t>(launch_params.params.ctas_per_col), hidden_size};
NVTE_CHECK(dgamma_part->data.dtype == ctype);
NVTE_CHECK(dgamma_part->data.shape == pdw_shape);
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size });
}
if (launch_params.workspace_bytes > 0) {
NVTE_CHECK(workspace->data.dptr != nullptr);
NVTE_CHECK(workspace->data.dtype == DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes });
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
......
......@@ -4,6 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h>
#include <iostream>
......@@ -56,6 +57,7 @@ template <int nvec_in, int nvec_out, typename CType, typename IType, typename OT
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_kernel(const IType * const input,
const CType * const noop,
OType * const output_c,
OType * const output_t,
const CType * const scale_ptr,
......@@ -63,6 +65,8 @@ cast_transpose_kernel(const IType * const input,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
if (noop != nullptr && noop[0] == 1.0f) return;
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
......@@ -163,6 +167,7 @@ template <int nvec_in, int nvec_out, typename CType, typename IType, typename OT
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_kernel_notaligned(const IType * const input,
const CType * const noop,
OType * const output_c,
OType * const output_t,
const CType * const scale_ptr,
......@@ -170,6 +175,8 @@ cast_transpose_kernel_notaligned(const IType * const input,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
if (noop != nullptr && noop[0] == 1.0f) return;
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
......@@ -294,6 +301,7 @@ cast_transpose_kernel_notaligned(const IType * const input,
}
void cast_transpose(const Tensor &input,
const Tensor &noop,
Tensor *cast_output,
Tensor *transposed_output,
cudaStream_t stream) {
......@@ -301,6 +309,22 @@ void cast_transpose(const Tensor &input,
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
// Number of elements in tensor
auto numel = [] (const Tensor &tensor) -> size_t {
size_t acc = 1;
for (const auto& dim : tensor.data.shape) {
acc *= dim;
}
return acc;
};
if (noop.data.dptr != nullptr) {
NVTE_CHECK(numel(noop) == 1,
"Expected 1 element, ",
"but found ", numel(noop), ".");
NVTE_CHECK(noop.data.dtype == DType::kFloat32);
NVTE_CHECK(noop.data.dptr != nullptr);
}
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
......@@ -332,6 +356,7 @@ void cast_transpose(const Tensor &input,
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>), \
stream>>>( \
reinterpret_cast<const InputType *>(input.data.dptr), \
reinterpret_cast<const fp32 *>(noop.data.dptr), \
reinterpret_cast<OutputType *>(cast_output->data.dptr), \
reinterpret_cast<OutputType *>(transposed_output->data.dptr), \
reinterpret_cast<const fp32 *>(cast_output->scale.dptr), \
......@@ -417,7 +442,23 @@ void nvte_cast_transpose(const NVTETensor input,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose);
using namespace transformer_engine;
auto noop = Tensor();
cast_transpose(*reinterpret_cast<const Tensor*>(input),
noop,
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
stream);
}
void nvte_cast_transpose_with_noop(const NVTETensor input,
const NVTETensor noop,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_with_noop);
using namespace transformer_engine;
cast_transpose(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(noop),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
stream);
......
......@@ -22,9 +22,12 @@ constexpr size_t block_size = __BLOCK_SIZE__;
__global__ void
__launch_bounds__(block_size)
transpose_optimized_kernel(const Type * __restrict__ const input,
const float * const noop,
Type * __restrict__ const output,
const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
constexpr size_t nvec_in = load_size / sizeof(Type);
constexpr size_t nvec_out = store_size / sizeof(Type);
......
......@@ -4,6 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h>
#include <iostream>
......@@ -30,9 +31,12 @@ template <size_t load_size, size_t store_size, typename Type>
__global__ void
__launch_bounds__(block_size)
transpose_general_kernel(const Type * __restrict__ const input,
const fp32 * const noop,
Type * __restrict__ const output,
const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
constexpr size_t nvec_in = load_size / sizeof(Type);
constexpr size_t nvec_out = store_size / sizeof(Type);
......@@ -124,6 +128,7 @@ transpose_general_kernel(const Type * __restrict__ const input,
}
void transpose(const Tensor &input,
const Tensor &noop,
Tensor *output_,
cudaStream_t stream) {
Tensor &output = *output_;
......@@ -140,6 +145,23 @@ void transpose(const Tensor &input,
NVTE_CHECK(input.data.dtype == output.data.dtype,
"Input and output type must match.");
// Number of elements in tensor
auto numel = [] (const Tensor &tensor) -> size_t {
size_t acc = 1;
for (const auto& dim : tensor.data.shape) {
acc *= dim;
}
return acc;
};
if (noop.data.dptr != nullptr) {
NVTE_CHECK(numel(noop) == 1,
"Expected 1 element, ",
"but found ", numel(noop), ".");
NVTE_CHECK(noop.data.dtype == DType::kFloat32);
NVTE_CHECK(noop.data.dptr != nullptr);
}
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.data.dtype, Type,
constexpr const char *type_name = TypeInfo<Type>::name;
constexpr size_t type_size = sizeof(Type);
......@@ -239,6 +261,7 @@ void transpose(const Tensor &input,
rtc_manager.launch(kernel_label,
num_blocks(load_size, store_size), block_size, 0, stream,
static_cast<const Type *>(input.data.dptr),
static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type*>(output.data.dptr),
row_length, num_rows);
} else { // Statically-compiled general kernel
......@@ -250,6 +273,7 @@ void transpose(const Tensor &input,
* DIVUP(num_rows, col_tile_size));
transpose_general_kernel<load_size, store_size, Type><<<num_blocks, block_size, 0, stream>>>(
static_cast<const Type *>(input.data.dptr),
static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type *>(output.data.dptr),
row_length, num_rows);
}
......@@ -263,7 +287,22 @@ void nvte_transpose(const NVTETensor input,
cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine;
auto noop = Tensor();
transpose(*reinterpret_cast<const Tensor*>(input),
noop,
reinterpret_cast<Tensor*>(output),
stream);
}
void nvte_transpose_with_noop(const NVTETensor input,
const NVTETensor noop,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose_with_noop);
using namespace transformer_engine;
transpose(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(noop),
reinterpret_cast<Tensor*>(output),
stream);
}
......@@ -14,6 +14,7 @@ from .attention import MultiheadAttention
from .transformer import TransformerLayer
from .fp8 import fp8_autocast
from .fp8 import fp8_model_init
from .graph import make_graphed_callables
from .export import onnx_export
from .distributed import checkpoint
from .distributed import CudaRNGStatesTracker
......
......@@ -52,9 +52,14 @@ from transformer_engine.pytorch.distributed import (
get_distributed_world_size,
get_distributed_rank,
checkpoint,
set_all_rng_states,
CudaRNGStatesTracker,
graph_safe_rng_available,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("2.0.6")
......@@ -2401,10 +2406,13 @@ class DotProductAttention(torch.nn.Module):
assert (num_attention_heads % self.num_gqa_groups == 0
), "The number of attention heads must be divisible by the number of GQA groups!"
self.rng_states_tracker = None
if sequence_parallel or get_rng_state_tracker is None:
attention_dropout_ctx = nullcontext
else:
attention_dropout_ctx = get_rng_state_tracker().fork
self.rng_states_tracker = get_rng_state_tracker()
set_all_rng_states(self.rng_states_tracker.get_states())
attention_dropout_ctx = self.rng_states_tracker.fork
norm_factor = math.sqrt(self.hidden_size_per_attention_head)
......@@ -2648,6 +2656,14 @@ class DotProductAttention(torch.nn.Module):
assert (attn_mask_type in AttnMaskTypes
), f"Attention mask type {attn_mask_type} is not supported!"
if self.rng_states_tracker is not None and is_graph_capturing():
assert (
isinstance(self.rng_states_tracker, CudaRNGStatesTracker)
), "Unsupported RNG states tracker."
assert (
graph_safe_rng_available()
), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
if window_size is None:
window_size = self.window_size
......@@ -3695,7 +3711,8 @@ class MultiheadAttention(torch.nn.Module):
# ===================
projection_output = self.proj(
context_layer, is_first_microbatch=is_first_microbatch
context_layer,
is_first_microbatch=is_first_microbatch,
)
if self.return_bias:
......
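The `DotProductAttention` changes above keep a handle to the caller's RNG states tracker and register its states through `set_all_rng_states`, so attention dropout can be forked safely while a CUDA graph is being captured. A minimal sketch of supplying a tracker, modeled on the dummy tracker previously built in `test_numerics.py` (head counts, seed, and dropout value are illustrative):

```python
from transformer_engine.pytorch import DotProductAttention
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker

tracker = CudaRNGStatesTracker()
tracker.add("model-parallel-rng", 1234)  # named CUDA RNG state tracked for dropout

dpa = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attention_dropout=0.1,
    get_rng_state_tracker=lambda: tracker,  # module registers these states for graph capture
)
```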
......@@ -22,19 +22,26 @@ def fp8_cast_transpose_fused(
otype: tex.DType,
cast_out: Optional[torch.Tensor] = None,
transpose_out: Optional[torch.Tensor] = None,
noop_flag: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], None]:
"""Cast + Transpose with FP8 output"""
return_outputs = False
if cast_out is None or transpose_out is None:
cast_out = torch.empty_like(inp, dtype=torch.uint8)
if transpose_out is None:
transpose_out = torch.empty(
inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8
)
return_outputs = True
if cast_out is None:
cast_out = torch.empty_like(inp, dtype=torch.uint8)
return_outputs = True
if noop_flag is None:
noop_flag = torch.Tensor()
tex.fused_cast_transpose(
tex.fused_cast_transpose_noop(
inp,
noop_flag,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
......
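The `noop_flag` added to the wrapper above is a single-element float32 tensor passed down to the kernels earlier in this commit, which return immediately when it reads 1.0. Because the flag is read on the GPU at execution time, a captured CUDA graph can perform or skip the cast+transpose on replay simply by rewriting the flag in place. A hedged sketch of the calling convention (the FP8 metadata arguments are omitted; they come from a module's `fp8_meta` state):

```python
import torch
from transformer_engine.pytorch.cpp_extensions import fp8_cast_transpose_fused  # assumed import path

# 0.0 -> perform the fused cast+transpose; 1.0 -> kernels exit immediately,
# leaving cast_out/transpose_out untouched. Flip it in place between graph replays.
noop_flag = torch.zeros(1, dtype=torch.float32, device="cuda")

# fp8_cast_transpose_fused(inp, fp8_meta_tensor, fp8_tensor, otype,
#                          cast_out=cast_out, transpose_out=transpose_out,
#                          noop_flag=noop_flag)
```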