Unverified commit 73f8d90f, authored by Kirthi Shankar Sivamani and committed by GitHub

[PyTorch] cuda graph support (#575)



* FP8 cuda graphs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>

* Fix numerics
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* exclude torch compile from numerics tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More numerics fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm fusion from unfused path
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
parent 1b20f2d6
@@ -41,4 +41,6 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.onnx_export
+.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables
.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
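The newly documented make_graphed_callables entry point is the user-facing piece of this change. As orientation only, below is a minimal sketch of how it is exercised by the test added in this PR (tests/pytorch/test_cuda_graphs.py); the module, shapes, and learning setup are illustrative assumptions, while the positional sample-args argument and the num_warmup_iters and fp8_enabled keywords mirror that test.

import torch
import transformer_engine.pytorch as te

# Illustrative only: a single TE Linear layer and made-up shapes (seq, batch, hidden).
layer = te.Linear(64, 64, params_dtype=torch.float16)
sample_args = (torch.randn(32, 2, 64, device="cuda", dtype=torch.float16, requires_grad=True),)

# Capture the callable into a CUDA graph after a few warmup iterations.
# fp8_enabled=True additionally requires FP8-capable hardware (e.g. Hopper).
graphed_layer = te.make_graphed_callables(
    layer, sample_args, num_warmup_iters=10, fp8_enabled=False,
)

with te.fp8_autocast(enabled=False):
    out = graphed_layer(*sample_args)
out.sum().backward()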
@@ -9,9 +9,10 @@ set -e
pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
-PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
+PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
+PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
-pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
+NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
@@ -5,7 +5,6 @@
import functools
from importlib.metadata import version
import os
-import math
from typing import Any, Dict, List, Tuple, Union
from pkg_resources import packaging
@@ -28,15 +27,9 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_bwd,
fused_attn_fwd,
)
-from transformer_engine.pytorch.distributed import (
-_set_cuda_rng_state,
-CudaRNGStatesTracker,
-)
+from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
import transformer_engine.pytorch.fp8 as fp8
-from transformer_engine.pytorch.module.base import (
-TransformerEngineBaseModule,
-_prepare_backward,
-)
+from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
@@ -58,10 +51,18 @@ _cuda_rng_state = torch.cuda.get_rng_state()
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
-_set_cuda_rng_state(_cuda_rng_state)
+torch.cuda.set_rng_state(_cuda_rng_state)
+@pytest.fixture(autouse=True)
+def reset_global_fp8_state():
+yield
+fp8.FP8GlobalStateManager.reset()
@functools.cache
def _cudnn_version() -> Tuple[int, int, int]:
@@ -71,6 +72,7 @@ def _cudnn_version() -> Tuple[int, int, int]:
minor, patch = divmod(encoded_version, 100)
return (major, minor, patch)
class ModelConfig:
def __init__(
self,
@@ -103,6 +105,7 @@ class ModelConfig:
self.num_layers = num_layers
self.bias_shape = bias_shape
def _is_fused_attention_supported(
config: ModelConfig,
dtype: torch.dtype,
@@ -151,24 +154,28 @@ def _is_fused_attention_supported(
return True, backends
return False, backends
@functools.cache
def _is_flash_attention_2_available() -> bool:
"""Check if flash-attn 2.0+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2")
@functools.cache
def _is_flash_attention_2_1() -> bool:
"""Check if flash-attn 2.1+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.1")
@functools.cache
def _is_flash_attention_2_3() -> bool:
"""Check if flash-attn 2.3+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.3")
def _is_flash_attention_supported(config: ModelConfig) -> bool:
"""Check if FlashAttention supports a model configuration"""
if get_device_compute_capability() < (8, 0):
@@ -184,6 +191,7 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool:
return False
return True
def _is_unfused_attention_supported(config: ModelConfig) -> bool:
"""Check if UnfusedDotProductAttention supports a model configuration"""
if ("padding" in config.attn_mask_type):
@@ -192,6 +200,7 @@ def _is_unfused_attention_supported(config: ModelConfig) -> bool:
return False
return True
model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
@@ -200,11 +209,13 @@ model_configs_base = {
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
}
param_types = [torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]
def get_swa(seq_q, seq_kv, w=None):
"""Generate a random sliding window size (left, right) if w is None,
and create its equivalent attention mask in [seq_q, seq_kv] shape"""
@@ -216,6 +227,7 @@ def get_swa(seq_q, seq_kv, w=None):
ml = ~ ml
return w, ml
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@@ -313,6 +325,7 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
for i,_ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@@ -321,6 +334,7 @@ def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False)
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
@@ -337,6 +351,7 @@ model_configs_mask = {
"mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
}
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@@ -345,6 +360,7 @@ def test_dpa_mask(dtype, model_configs, model):
"""Test DotProductAttention module with different mask types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
@@ -373,6 +389,7 @@ model_configs_bias = {
"bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
}
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@@ -381,6 +398,7 @@ def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0,
@@ -398,6 +416,7 @@ model_configs_bias_shapes = {
"causal", "alibi", bias_shape='bhss', alibi_type='custom'),
}
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@@ -413,6 +432,8 @@ model_configs_swa = {
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
}
@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@@ -428,6 +449,8 @@ model_configs_alibi_slopes = {
"alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type= "custom"),
"alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type= "custom"),
}
@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@@ -436,6 +459,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
"""Test DotProductAttention module with ALiBi slopes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
@@ -443,6 +467,7 @@ qkv_layouts = [
#'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd',
]
model_configs_layout = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
@@ -455,6 +480,7 @@ model_configs_layout = {
"layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
}
@pytest.mark.skipif(_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@@ -464,6 +490,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False)
def _run_dot_product_attention(
dtype: torch.dtype,
config: ModelConfig,
@@ -646,6 +673,7 @@ def _run_dot_product_attention(
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
@@ -658,6 +686,7 @@ model_configs_te_layer = {
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
}
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@@ -742,6 +771,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@@ -755,6 +785,7 @@ def test_te_layer_misc(dtype, model_configs, model, qkv_format):
test_transformer_layer(dtype, model_configs, model,
ckpt_attn, qkv_format, fused_qkv_params, RoPE)
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@@ -780,6 +811,7 @@ def test_te_layer_mqa_gqa(dtype, model_configs, model):
test_transformer_layer(dtype, model_configs, model,
ckpt_attn, qkv_format, fused_qkv_params, RoPE)
def _run_transformer_layer(
dtype: torch.dtype,
config: ModelConfig,
@@ -912,8 +944,10 @@ model_configs_fp8 = {
"fp8_1": ModelConfig(1, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
}
param_types_fp8 = [torch.float16]
@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@@ -946,6 +980,7 @@ def test_dpa_fp8(dtype, model):
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
def _run_dpa_fp8(dtype, config, backend):
"""Run FusedAttention FP8 backend, i.e.
fused_attn_fwd/bwd_qkvpacked from cpp_extensions"""
@@ -989,6 +1024,7 @@ def _run_dpa_fp8(dtype, config, backend):
dqkv.view(config.batch_size, config.max_seqlen_q, 3,
config.num_heads, config.head_dim).transpose(0,1).contiguous())
def _run_dpa_fp8_ref(dtype, config, backend):
"""Run UnfusedDotProductAttention as a reference, i.e.
plain PyTorch implementation in FP16 and inputs/outputs
@@ -1188,8 +1224,7 @@ class _dpa_fp8(torch.autograd.Function):
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
-with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"):
+with torch.cuda.nvtx.range("_DPA"):
(
inputmat_t,
qkv_weight_t_fp8,
@@ -1298,6 +1333,7 @@ class _dpa_fp8(torch.autograd.Function):
None,
None)
class DPA_FP8(TransformerEngineBaseModule):
def __init__(
self,
...
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import List, Tuple
import pytest
import torch
from transformer_engine.pytorch import (
DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, make_graphed_callables,
MultiheadAttention, TransformerLayer, fp8_autocast, fp8_model_init,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
class ModelConfig:
def __init__(self, hidden_size, nheads, kv, seq_len):
self.h = hidden_size
self.nheads = nheads
self.kv = kv
self.s = seq_len
model_configs = {
"small": ModelConfig(64, 2, 32, 32),
}
modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]
optimizers = [torch.optim.SGD, torch.optim.Adam]
all_boolean = [True, False]
dtypes = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
dtypes.append(torch.bfloat16)
def reset_rng_states() -> None:
"""revert back to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
FP8GlobalStateManager.reset()
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
failed = False
failed_tensors = ""
for i, (t1, t2) in enumerate(zip(l1, l2)):
with torch.no_grad():
t1.masked_fill_(t1.isnan(), 1.0)
t2.masked_fill_(t2.isnan(), 1.0)
if not torch.equal(t1, t2):
failed = True
failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
assert not failed, "Output mismatches in:\n" + failed_tensors
def generate_data(
s: int, b: int, h: int, nheads: int, kv: int, dtype: torch.dtype,
dpa: bool = False, warmup: bool = False, gen_labels: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generate synthetic data."""
gen_func = torch.ones if warmup else torch.randn
if dpa:
inputs = [gen_func(s, b, nheads, kv, device="cuda", requires_grad=True, dtype=dtype) for _ in range(3)]
else:
inputs = [gen_func(s, b, h, device="cuda", requires_grad=True, dtype=dtype)]
if not gen_labels:
return inputs
target = torch.randn(s, b, h, device="cuda", dtype=dtype)
return inputs, target
def get_outputs(model, output):
"""Return grads and params for comparsion."""
values = []
for param in model.parameters():
values.append(param)
if param.grad is not None:
values.append(param.grad)
values.append(output)
return values
def _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, graph, module, optimizer, graph_mode=""):
"""Helper function for test."""
reset_rng_states()
FP8GlobalStateManager.reset()
dpa = module == "dpa"
with fp8_model_init(enabled=fp8_params):
# Create modules.
if module == "transformer":
modules = [TransformerLayer(
config.h,
config.h,
config.nheads,
hidden_dropout=0.0,
attention_dropout=0.0,
fuse_qkv_params=True,
params_dtype=dtype,
) for _ in range(num_layers)]
elif module == "layernorm_mlp":
modules = [LayerNormMLP(
config.h, config.h, params_dtype=dtype
) for _ in range(num_layers)]
elif module == "layernorm_linear":
modules = [LayerNormLinear(
config.h, config.h, params_dtype=dtype
) for _ in range(num_layers)]
elif module == "mha":
modules = [MultiheadAttention(
config.h,
config.nheads,
attention_dropout=0.0,
params_dtype=dtype,
fuse_qkv_params=True,
) for _ in range(num_layers)]
elif dpa:
assert config.h % config.nheads == 0, "Err."
assert num_layers == 1, "Err."
modules = [DotProductAttention(
config.nheads, config.kv, attention_dropout=0.0
) for _ in range(num_layers)]
else:
modules = [Linear(
config.h, config.h, device="cuda", params_dtype=dtype
) for _ in range(num_layers)]
# Generate model and wrap API to return graphed version.
if graph:
# Graph entire module at once.
if graph_mode == "full":
model = modules[0] if dpa else torch.nn.Sequential(*modules)
model = make_graphed_callables(
model,
generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True),
num_warmup_iters=10,
fp8_enabled=fp8)
else:
modules = [make_graphed_callables(
module,
generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True),
num_warmup_iters=10,
fp8_enabled=fp8) for module in modules]
model = modules[0] if dpa else torch.nn.Sequential(*modules)
else:
model = modules[0] if dpa else torch.nn.Sequential(*modules)
# Loss function and optimizer.
loss_fn = torch.nn.MSELoss()
if not dpa:
optimizer = optimizer(model.parameters(), lr=0.001)
# Launch.
for _ in range(10):
inputs, target = generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, gen_labels=True)
with fp8_autocast(enabled=fp8):
output = model(*inputs)
loss = loss_fn(output, target)
loss.backward()
if not dpa:
optimizer.step()
optimizer.zero_grad()
return get_outputs(model, output)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("num_layers", [1, 10])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_params", all_boolean)
@pytest.mark.parametrize("module", modules)
@pytest.mark.parametrize("optimizer", optimizers)
def test_gpt_make_graphed_callables(dtype, bs, model, num_layers, fp8, fp8_params, module, optimizer):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_params and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if module == "dpa" and num_layers > 1:
pytest.skip("Max 1 layer for DPA.")
config = model_configs[model]
outputs = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, False, module, optimizer)
graph_outputs_mode1 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="full")
graph_outputs_mode2 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="individual")
# Check that results match
assert_all_equal(outputs, graph_outputs_mode1)
assert_all_equal(outputs, graph_outputs_mode2)
@@ -257,12 +257,10 @@ class TestFloat8Tensor:
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, x_ref, **tols)
-@pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]])
-@pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)])
+@pytest.mark.parametrize("dims", [[33, 41], [7, 11]])
def test_transpose(
self,
dims: DimsType,
-transpose_dims: Tuple[int, int],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 0.5,
dtype: torch.dtype = torch.float32,
@@ -271,74 +269,44 @@
# Initialize random data
dims = _to_list(dims)
-x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
+x = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
-x_ref,
+x,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
-x_ref = x_fp8.from_float8()
+x = x_fp8.from_float8()
# Perform transpose
-y_fp8 = x_fp8.transpose(*transpose_dims)
-y_ref = x_ref.transpose(*transpose_dims)
+x_fp8_t = x_fp8.transpose_2d()
+x_t = x.transpose(0, 1)
+x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8_t)
# Check results
tols = dict(rtol=0, atol=0)
-torch.testing.assert_close(y_fp8, y_ref, **tols)
+torch.testing.assert_close(x_fp8_t, x_t, **tols)
# Make sure we are not trivially passing the test
-if transpose_dims[0] != transpose_dims[1]:
with pytest.raises(AssertionError):
-torch.testing.assert_close(
-y_fp8,
-x_ref,
-**tols,
-)
+torch.testing.assert_close(x_fp8_t, x, **tols)
-# Check transpose caching
-if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]:
-# Check that cached transpose is returned when expected
-# Note: Sneakily destroy data so that recalculating
-# transpose would give wrong answer.
+# Caching test.
+assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching."
x_fp8 += 0.5
-x_ref = x_fp8.from_float8()
-torch.testing.assert_close(
-x_fp8.transpose(*transpose_dims, update_cache="lazy"),
-x_ref.transpose(*transpose_dims),
-**tols,
-)
-x_fp8_data = x_fp8._data.clone()
-x_fp8._data.zero_()
-torch.testing.assert_close(
-x_fp8.transpose(*transpose_dims),
-x_ref.transpose(*transpose_dims),
-**tols,
-)
-torch.testing.assert_close(
-x_fp8.transpose(*transpose_dims, update_cache="lazy"),
-x_ref.transpose(*transpose_dims),
-**tols,
-)
-torch.testing.assert_close(
-x_fp8.transpose(*transpose_dims, update_cache="force"),
-torch.zeros_like(x_ref.transpose(*transpose_dims)),
-rtol=0,
-atol=0,
-)
-x_fp8._data.copy_(x_fp8_data)
-x_fp8._reset_caches()
+x = x_fp8.from_float8()
+x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(cache=True))
+x_t = x.transpose(0, 1)
+torch.testing.assert_close(x_fp8_t, x_t, **tols)
+assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
-# Make sure cache is reset after in-place operation
-x_fp8.transpose(*transpose_dims, update_cache="force")
+# Inplace update test.
x_fp8 += 0.5
-x_ref = x_fp8.from_float8()
-torch.testing.assert_close(
-x_fp8.transpose(*transpose_dims),
-x_ref.transpose(*transpose_dims),
-**tols,
-)
+assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly."
+x = x_fp8.from_float8()
+x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(cache=True))
+x_t = x.transpose(0, 1)
+torch.testing.assert_close(x_fp8_t, x_t, **tols)
+assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly."
def test_serialization(
self,
...
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import math import math
import os import os
import sys
from typing import List, Optional from typing import List, Optional
import pytest import pytest
import copy import copy
...@@ -25,7 +24,6 @@ from transformer_engine.pytorch import ( ...@@ -25,7 +24,6 @@ from transformer_engine.pytorch import (
MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, InferenceParams MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, InferenceParams
) )
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
# Only run FP8 tests on H100. # Only run FP8 tests on H100.
...@@ -54,6 +52,14 @@ model_configs = { ...@@ -54,6 +52,14 @@ model_configs = {
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048), "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
} }
model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
}
backends_inference = ["FlashAttention", "UnfusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
param_types = [torch.float32, torch.float16] param_types = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16) param_types.append(torch.bfloat16)
...@@ -104,7 +110,13 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) ...@@ -104,7 +110,13 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float)
def reset_rng_states() -> None: def reset_rng_states() -> None:
"""revert back to initial RNG state.""" """revert back to initial RNG state."""
torch.set_rng_state(_cpu_rng_state) torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state) torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
FP8GlobalStateManager.reset()
class TorchScaledMaskedSoftmax(nn.Module): class TorchScaledMaskedSoftmax(nn.Module):
...@@ -373,10 +385,10 @@ class TorchGPT(nn.Module): ...@@ -373,10 +385,10 @@ class TorchGPT(nn.Module):
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
a = self.ln(x) a = self.ln(x)
b = self.causal_attn(a, attn_mask) b = self.causal_attn(a, attention_mask)
if self.parallel_attention_mlp: if self.parallel_attention_mlp:
n = self.ln_mlp(x) n = self.ln_mlp(x)
x = x + nn.functional.dropout(b + n, p=0.1, training=self.training) x = x + nn.functional.dropout(b + n, p=0.1, training=self.training)
...@@ -396,13 +408,6 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False ...@@ -396,13 +408,6 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
with fp8_model_init(enabled=fp8 and fp8_model_params): with fp8_model_init(enabled=fp8 and fp8_model_params):
block = ( block = (
TransformerLayer( TransformerLayer(
...@@ -417,7 +422,6 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False ...@@ -417,7 +422,6 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False
kv_channels=config.embed, kv_channels=config.embed,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype, params_dtype=dtype,
fuse_qkv_params=True, fuse_qkv_params=True,
) )
...@@ -476,13 +480,6 @@ def _test_e2e_full_recompute( ...@@ -476,13 +480,6 @@ def _test_e2e_full_recompute(
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
with fp8_model_init(enabled=fp8 and fp8_model_params): with fp8_model_init(enabled=fp8 and fp8_model_params):
block = ( block = (
TransformerLayer( TransformerLayer(
...@@ -497,7 +494,6 @@ def _test_e2e_full_recompute( ...@@ -497,7 +494,6 @@ def _test_e2e_full_recompute(
kv_channels=config.embed, kv_channels=config.embed,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype, params_dtype=dtype,
fuse_qkv_params=True, fuse_qkv_params=True,
) )
...@@ -520,7 +516,6 @@ def _test_e2e_full_recompute( ...@@ -520,7 +516,6 @@ def _test_e2e_full_recompute(
checkpoint_core_attention=False, checkpoint_core_attention=False,
distribute_saved_activations=False, distribute_saved_activations=False,
tp_group=None, tp_group=None,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
use_reentrant=use_reentrant, use_reentrant=use_reentrant,
) )
else: else:
...@@ -683,7 +678,7 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): ...@@ -683,7 +678,7 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
inp_hidden_states.retain_grad() inp_hidden_states.retain_grad()
inp_attn_mask = get_causal_attn_mask(config.seq_len) inp_attn_mask = get_causal_attn_mask(config.seq_len)
out = block(inp_hidden_states, inp_attn_mask) out = block(inp_hidden_states, attention_mask=inp_attn_mask)
loss = out.sum() loss = out.sum()
loss.backward() loss.backward()
...@@ -1261,13 +1256,6 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): ...@@ -1261,13 +1256,6 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
with fp8_model_init(enabled=fp8_model_params): with fp8_model_init(enabled=fp8_model_params):
block = ( block = (
TransformerLayer( TransformerLayer(
...@@ -1282,7 +1270,6 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): ...@@ -1282,7 +1270,6 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):
kv_channels=config.embed, kv_channels=config.embed,
apply_residual_connection_post_layernorm=False, apply_residual_connection_post_layernorm=False,
output_layernorm=False, output_layernorm=False,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype, params_dtype=dtype,
fuse_qkv_params=True, fuse_qkv_params=True,
) )
...@@ -1321,6 +1308,7 @@ def test_gpt_fp8_parameters(dtype, bs, model): ...@@ -1321,6 +1308,7 @@ def test_gpt_fp8_parameters(dtype, bs, model):
outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True)
assert_all_equal(outputs, outputs_fp8_params) assert_all_equal(outputs, outputs_fp8_params)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
...@@ -1399,14 +1387,6 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ...@@ -1399,14 +1387,6 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()]) assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()])
model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
}
backends_inference = ["FlashAttention", "UnfusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model_key", model_configs_inference.keys()) @pytest.mark.parametrize("model_key", model_configs_inference.keys())
......
@@ -86,6 +86,12 @@ def set_max_seq_len(max_seq_len=128):
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}"
+@pytest.fixture(autouse=True)
+def reset_global_fp8_state():
+yield
+FP8GlobalStateManager.reset()
def create_fp8_recipe():
return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
...
@@ -48,6 +48,7 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
"""Custom func to test recipe."""
return torch.min(amax_history, dim=0).values
@dataclass
class ModelConfig:
"""Transformer model configuration"""
@@ -115,6 +116,12 @@ def _disable_wgrads(block):
p.requires_grad = False
+@pytest.fixture(autouse=True)
+def reset_global_fp8_state():
+yield
+FP8GlobalStateManager.reset()
def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
# Initialize loss function and optimizer.
loss_fn = torch.nn.MSELoss()
@@ -137,7 +144,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
with torch.cuda.stream(s):
for _ in range(3):
optimizer.zero_grad(set_to_none=True)
-with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
out = block(static_input)
loss = loss_fn(out, static_target)
loss.backward()
@@ -148,7 +155,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
-with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True):
static_output = block(static_input)
static_loss = loss_fn(static_output, static_target)
static_loss.backward()
...
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file transpose_with_noop.h
* \brief Functions handling transposes with no-op.
*/
#ifndef TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_
#define TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
void nvte_transpose_with_noop(const NVTETensor input,
const NVTETensor noop,
NVTETensor output,
cudaStream_t stream);
void nvte_cast_transpose_with_noop(const NVTETensor input,
const NVTETensor noop,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_
@@ -56,6 +56,45 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his
float margin,
cudaStream_t stream);
/*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction.
*
* Operations performed include, updating the most recent amax history
* with the relevant segment of global reduction buffer if it's not 0,
* rotating the amax history based on the rule below, and updating the
* scales and scale_invs.
*
* The amax history is rotated by -1 (e.g. the first entry shifts to
* the last, the last entry shifts to the second to last) and the
* first entry is set to zero. The scaling factor is estimated so the
* FP8 tensor's maximum absolute value is
* @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$.
*
* \param[in] amax_reduction_buffer The contiguous buffer used for amax reduction.
* Shape: [num_scales * num_tensors]
* \param[in,out] amax_histories List of amax histories of maximum absolute values.
* Shape: num_tensors x [history_length, num_scales]
* \param[in,out] scales List of scaling factors for casting to FP8.
* Shape: num_tensors x [num_scales]
* \param[in,out] scale_invs List of scaling factors for casting from FP8.
* Shape: num_tensors x [num_scales]
* \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and
* "most_recent".
* \param[in] fp8_dtype FP8 datatype.
* \param[in] margin Scaling factor margin.
* \param[in] stream CUDA stream.
*/
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
const NVTETensor amax_reduction_buffer,
std::vector<NVTETensor> amax_histories,
std::vector<NVTETensor> scales,
std::vector<NVTETensor> scale_invs,
const char *amax_compute_algo,
NVTEDType fp8_dtype,
float margin,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
...
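To make the recipe arithmetic described in the docstring above concrete, here is a small reference-only Python sketch of the per-scale update it performs (roll the amax history by -1, zero the first entry, reduce the history, and recompute the scaling factors so amax maps to 2**(-margin) * fp8_max). It is an illustration under those stated assumptions, not the CUDA implementation, and the helper name is made up.

import math

def delayed_scaling_update(amax_history, scale, margin, fp8_max, amax_compute_algo="max"):
    """Reference-only sketch of one tensor's delayed-scaling update."""
    length = len(amax_history)
    last_amax = amax_history[0]  # or the value taken from the amax reduction buffer, if nonzero
    # Amax used for the new scaling factor.
    amax = max(amax_history) if amax_compute_algo == "max" else last_amax
    # Rotate the history by -1 and zero out the first entry.
    rolled = [amax_history[i + 1] if i < length - 1 else last_amax for i in range(length)]
    rolled[0] = 0.0
    # New scale so that amax maps to 2**(-margin) * fp8_max; keep the old scale if amax is unusable.
    if amax > 0 and math.isfinite(amax):
        scale = (fp8_max * 2.0 ** (-margin)) / amax
    return rolled, scale, 1.0 / scale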
@@ -229,19 +229,29 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
+if (launch_params.workspace_bytes == 0) {
+launch_params.workspace_bytes = 1;
+}
if (workspace->data.dptr == nullptr) {
NVTE_CHECK(barrier->data.dptr == nullptr);
workspace->data.dtype = layer_norm::DType::kByte;
-if (launch_params.workspace_bytes == 0) {
-launch_params.workspace_bytes = 1;
-}
workspace->data.shape = { launch_params.workspace_bytes };
barrier->data.dtype = layer_norm::DType::kInt32;
barrier->data.shape = { launch_params.barrier_size };
return;
+} else {
+NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte);
+NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes });
+}
+if (launch_params.barrier_size > 0) {
+NVTE_CHECK(barrier->data.dptr != nullptr);
+NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32);
+NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size });
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
@@ -368,6 +378,27 @@ void layernorm_bwd(const Tensor& dz,
barrier->data.shape = { launch_params.barrier_size };
return;
+} else {
+NVTE_CHECK(dbeta_part->data.dptr != nullptr);
+auto pdw_shape = std::vector<size_t>{
+static_cast<uint64_t>(launch_params.params.ctas_per_col), hidden_size};
+NVTE_CHECK(dgamma_part->data.dtype == ctype);
+NVTE_CHECK(dgamma_part->data.shape == pdw_shape);
+NVTE_CHECK(dbeta_part->data.dtype == ctype);
+NVTE_CHECK(dbeta_part->data.shape == pdw_shape);
+}
+if (launch_params.barrier_size > 0) {
+NVTE_CHECK(barrier->data.dptr != nullptr);
+NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32);
+NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size });
+}
+if (launch_params.workspace_bytes > 0) {
+NVTE_CHECK(workspace->data.dptr != nullptr);
+NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte);
+NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes });
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
...
@@ -133,3 +133,13 @@ class DelayedScaling:
(False, False, False),
(False, False, True),
), "Only wgrad GEMM override is currently supported."
def __repr__(self) -> str:
return (
f"margin={self.margin}, "
f"interval={self.interval}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"amax_history_len={self.amax_history_len}, "
f"wgrad_override={self.override_linear_precision.wgrad}, "
f"reduce_amax={self.reduce_amax}"
)
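The __repr__ added here makes it easier to see which recipe settings a run is using, e.g. when debugging graphed FP8 execution. A small usage sketch follows; the printed field values are illustrative assumptions, not exact output.

from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
# Prints something along the lines of:
#   margin=0, interval=1, format=E4M3, amax_history_len=..., wgrad_override=..., reduce_amax=...
print(fp8_recipe)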
@@ -11,6 +11,7 @@
#include "../common.h"
#include "../util/logging.h"
+#include "../util/cuda_runtime.h"
namespace transformer_engine {
namespace delayed_scaling_recipe {
@@ -38,6 +39,36 @@ inline float fp8_dtype_max(DType dtype) {
return 0;
}
// struct for amax parameters
struct AmaxParam {
int num_scale = 0;
float* amax_history = nullptr;
float* scale = nullptr;
float* scale_inv = nullptr;
};
// dummy struct for kernel_bulk's other params
struct OtherParams {
float* a;
size_t b;
AmaxComputeAlgo c;
float d;
};
#if CUDART_VERSION >= 12010
constexpr size_t max_constant_memory_per_kernel = 32000;
constexpr size_t AMAX_PARAMS_LIMIT = (
max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
#else
constexpr size_t max_constant_memory_per_kernel = 4000;
constexpr size_t AMAX_PARAMS_LIMIT = (
max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
#endif
struct AmaxParams {
AmaxParam param[AMAX_PARAMS_LIMIT];
};
namespace amax_and_scale_update_impl {
// CUDA block size
@@ -133,11 +164,96 @@ kernel(const float* amax_history_ptr,
}
}
-}  // namespace amax_and_scale_update_impl
/* CUDA kernel to bulk-update amax history and FP8 scaling factors
*
* Block dims: bsize x 1 x 1
*
* Grid dims: num_tensors x 1 x 1
*/
__global__ void __launch_bounds__(bsize)
kernel_bulk(
float* amax_reduction_buffer,
AmaxParams p,
size_t amax_history_length,
AmaxComputeAlgo amax_compute_algo,
float scaled_max) {
const size_t bid = blockIdx.x;
const size_t tid = threadIdx.x;
const int num_scale = p.param[bid].num_scale;
int offset_in_buffer = 0;
for (int j = 0; j < bid; j++) {
offset_in_buffer += p.param[j].num_scale;
}
for (int count = 0; count < num_scale; count++) {
// Update amax
float amax = 0;
{
// Roll amax history
const auto& length = amax_history_length;
const auto& stride = p.param[bid].num_scale;
auto* amax_history = p.param[bid].amax_history+count;
const auto last_amax = ((amax_reduction_buffer != nullptr)
&& (amax_reduction_buffer[offset_in_buffer+count] != 0.0f)) ?
amax_reduction_buffer[offset_in_buffer+count] : amax_history[0];
for (size_t off = 0; off < length; off += bsize) {
const size_t i = off + tid;
float a = 0;
if (i < length) {
a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax;
amax = fmaxf(amax, a);
}
__syncthreads(); // Inplace roll
if (i < length) {
amax_history[i*stride] = (i > 0) ? a : 0;
}
}
// Compute amax to use for scaling factor
switch (amax_compute_algo) {
case AmaxComputeAlgo::MOST_RECENT:
amax = last_amax;
break;
case AmaxComputeAlgo::MAX:
{
__shared__ float shared_amax[bsize];
shared_amax[tid] = amax;
__syncthreads();
#pragma unroll
for (size_t off = bsize / 2; off > 0; off /= 2) {
if (tid < off) {
shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]);
}
__syncthreads();
}
amax = shared_amax[tid];
}
break;
default:
amax = 0;
}
}
// Update scale and scale inverse
if (tid == 0) {
float scale;
if (isfinite(amax) && amax > 0) {
scale = scaled_max / amax;
} else {
scale = p.param[bid].scale[count];
}
p.param[bid].scale[count] = scale;
p.param[bid].scale_inv[count] = 1 / scale;
}
}
}
} // namespace amax_and_scale_update_impl
} // namespace
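Before the host-side changes below, here is a minimal NumPy sketch (not the shipped code) of what kernel_bulk computes for one tensor slot: roll the amax history, reduce it with either the "max" or "most_recent" rule, and refresh scale and scale_inv. It assumes a history length of at least 2 and takes scaled_max = fp8_max * 2**(-margin), as in the host code that follows.

import numpy as np

def update_one_slot(history, old_scale, scaled_max, algo="max", reduced_amax=None):
    # Slot 0 of `history` holds the amax recorded during the last step,
    # unless a non-zero value arrived through the reduction buffer.
    last = reduced_amax if reduced_amax else history[0]
    candidates = np.append(history[1:], last)
    amax = candidates.max() if algo == "max" else last
    # Roll the history: clear slot 0 for the next step, shift the rest back,
    # and store the chosen last amax in the final slot.
    new_history = np.concatenate(([0.0], history[2:], [last]))
    scale = scaled_max / amax if np.isfinite(amax) and amax > 0 else old_scale
    return new_history, scale, 1.0 / scale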
void amax_and_scale_update(const Tensor &amax_history,
const Tensor &scale,
const Tensor &scale_inv,
...@@ -238,9 +354,105 @@ void amax_and_scale_update(const Tensor &amax_history,
NVTE_CHECK_CUDA(cudaGetLastError());
}
void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer,
std::vector<Tensor*> amax_histories,
std::vector<Tensor*> scales,
std::vector<Tensor*> scale_invs,
const std::string &amax_compute_algo,
DType fp8_dtype,
float margin,
cudaStream_t stream) {
using namespace transformer_engine;
// amax value to use for updating scaling factor
AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID;
if (amax_compute_algo == "max") {
amax_compute_algo_ = AmaxComputeAlgo::MAX;
} else if (amax_compute_algo == "most_recent") {
amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT;
} else {
NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")");
}
// Expected maximum value after scale is applied
const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin);
// Number of elements in tensor
auto numel = [] (const Tensor *tensor) -> size_t {
size_t acc = 1;
for (const auto& dim : tensor->data.shape) {
acc *= dim;
}
return acc;
};
// Number of tensors in the bulk
const size_t num_tensors = amax_histories.size();
const int num_kernels = (num_tensors+AMAX_PARAMS_LIMIT-1)/AMAX_PARAMS_LIMIT;
size_t amax_history_length = 0;
if (num_tensors > 0) {
amax_history_length = amax_histories[0]->data.shape[0];
}
// amax parameters
float* amax_buffer = static_cast<float*>(amax_reduction_buffer.data.dptr);
AmaxParams p;
for (int iter = 0; iter < num_kernels; iter++) {
size_t kernel_num_scales = 0;
size_t kernel_num_tensors = (iter == (num_kernels -1))
? num_tensors % AMAX_PARAMS_LIMIT: AMAX_PARAMS_LIMIT;
for (size_t pi = 0; pi < kernel_num_tensors; pi++) {
size_t i = iter * AMAX_PARAMS_LIMIT + pi;
// Check tensors
int num_scale = amax_histories[i]->data.shape[1];
NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32,
"Found ", dtype_name(amax_histories[i]->data.dtype), ".");
NVTE_CHECK(amax_histories[i]->data.shape.size() == 2,
"Found ", amax_histories[i]->data.shape.size(), " dims");
NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale,
"Expected ", amax_history_length * num_scale, " elements, ",
"but found ", numel(amax_histories[i]), ".");
NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32,
"Found ", dtype_name(scales[i]->data.dtype), ".");
NVTE_CHECK(scales[i]->data.shape.size() == 1,
"Found ", scales[i]->data.shape.size(), " dims");
NVTE_CHECK(numel(scales[i]) == num_scale,
"Expected ", num_scale, " elements, ",
"Found ", numel(scales[i]), ".");
// amax parameters
kernel_num_scales += num_scale;
p.param[pi].num_scale = num_scale;
p.param[pi].amax_history = static_cast<float*>(amax_histories[i]->data.dptr);
p.param[pi].scale = static_cast<float*>(scales[i]->data.dptr);
p.param[pi].scale_inv = static_cast<float*>(scale_invs[i]->data.dptr);
}
// Launch CUDA kernel
size_t grid_size = kernel_num_tensors;
const size_t block_size = amax_and_scale_update_impl::bsize;
amax_and_scale_update_impl::kernel_bulk
<<<grid_size, block_size, 0, stream>>>(
amax_buffer,
p,
amax_history_length,
amax_compute_algo_,
scaled_max);
NVTE_CHECK_CUDA(cudaGetLastError());
// shift amax buffer pointer
if (amax_buffer != nullptr) {
amax_buffer += kernel_num_scales;
}
}
}
} // namespace delayed_scaling_recipe
} // namespace transformer_engine
void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history,
const NVTETensor scale,
const NVTETensor scale_inv,
...@@ -267,3 +479,33 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his
margin,
stream);
}
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
const NVTETensor amax_reduction_buffer,
std::vector<NVTETensor> amax_histories,
std::vector<NVTETensor> scales,
std::vector<NVTETensor> scale_invs,
const char *amax_compute_algo,
NVTEDType fp8_dtype,
float margin,
cudaStream_t stream) {
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction);
using namespace transformer_engine;
size_t num_tensors = amax_histories.size();
std::vector<Tensor*> t_amax_histories, t_scales, t_scale_invs;
for (size_t i = 0; i < num_tensors; i++) {
t_amax_histories.push_back(reinterpret_cast<Tensor*>(amax_histories[i]));
t_scales.push_back(reinterpret_cast<Tensor*>(scales[i]));
t_scale_invs.push_back(reinterpret_cast<Tensor*>(scale_invs[i]));
}
delayed_scaling_recipe::amax_and_scale_update_after_reduction(
*reinterpret_cast<const Tensor*>(amax_reduction_buffer),
t_amax_histories,
t_scales,
t_scale_invs,
amax_compute_algo,
static_cast<DType>(fp8_dtype),
margin,
stream);
}
...@@ -153,21 +153,32 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
if (workspace->data.dptr == nullptr) {
NVTE_CHECK(barrier->data.dptr == nullptr);
workspace->data.dtype = DType::kByte;
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
workspace->data.shape = {launch_params.workspace_bytes};
barrier->data.dtype = DType::kInt32;
barrier->data.shape = {launch_params.barrier_size};
return;
} else {
NVTE_CHECK(workspace->data.dtype == DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes });
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size });
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
...@@ -265,6 +276,23 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
barrier->data.shape = {launch_params.barrier_size};
return;
} else {
auto pdw_shape = std::vector<size_t>{
static_cast<uint64_t>(launch_params.params.ctas_per_col), hidden_size};
NVTE_CHECK(dgamma_part->data.dtype == ctype);
NVTE_CHECK(dgamma_part->data.shape == pdw_shape);
}
if (launch_params.barrier_size > 0) {
NVTE_CHECK(barrier->data.dptr != nullptr);
NVTE_CHECK(barrier->data.dtype == DType::kInt32);
NVTE_CHECK(barrier->data.shape == std::vector<size_t>{ launch_params.barrier_size });
}
if (launch_params.workspace_bytes > 0) {
NVTE_CHECK(workspace->data.dptr != nullptr);
NVTE_CHECK(workspace->data.dtype == DType::kByte);
NVTE_CHECK(workspace->data.shape == std::vector<size_t>{ launch_params.workspace_bytes });
}
// Tensor checks are delayed here in order to recover workspace sizes with null data
...
...@@ -4,6 +4,7 @@
 * See LICENSE for license information.
 ************************************************************************/
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h>
#include <iostream>
...@@ -56,6 +57,7 @@ template <int nvec_in, int nvec_out, typename CType, typename IType, typename OT
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_kernel(const IType * const input,
const CType * const noop,
OType * const output_c,
OType * const output_t,
const CType * const scale_ptr,
...@@ -63,6 +65,8 @@ cast_transpose_kernel(const IType * const input,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
if (noop != nullptr && noop[0] == 1.0f) return;
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
...@@ -163,6 +167,7 @@ template <int nvec_in, int nvec_out, typename CType, typename IType, typename OT
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_kernel_notaligned(const IType * const input,
const CType * const noop,
OType * const output_c,
OType * const output_t,
const CType * const scale_ptr,
...@@ -170,6 +175,8 @@ cast_transpose_kernel_notaligned(const IType * const input,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
if (noop != nullptr && noop[0] == 1.0f) return;
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
...@@ -294,6 +301,7 @@ cast_transpose_kernel_notaligned(const IType * const input,
}
void cast_transpose(const Tensor &input,
const Tensor &noop,
Tensor *cast_output,
Tensor *transposed_output,
cudaStream_t stream) {
...@@ -301,6 +309,22 @@ void cast_transpose(const Tensor &input,
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
// Number of elements in tensor
auto numel = [] (const Tensor &tensor) -> size_t {
size_t acc = 1;
for (const auto& dim : tensor.data.shape) {
acc *= dim;
}
return acc;
};
if (noop.data.dptr != nullptr) {
NVTE_CHECK(numel(noop) == 1,
"Expected 1 element, ",
"but found ", numel(noop), ".");
NVTE_CHECK(noop.data.dtype == DType::kFloat32);
NVTE_CHECK(noop.data.dptr != nullptr);
}
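An aside on what the noop tensor added here is for (a minimal sketch, not part of this commit; the flag name mirrors the Python wrapper later in this diff): under CUDA graph capture the kernel launch is frozen into the graph, so the only way to skip the work at replay time is a device-side flag that the kernel itself reads.

import torch

# One-element float32 flag on the GPU; 0.0 means "do the cast/transpose as usual".
noop_flag = torch.zeros(1, dtype=torch.float32, device="cuda")

# ... capture a CUDA graph whose kernels read `noop_flag` ...

noop_flag.fill_(1.0)  # later replays early-exit inside the captured kernels
noop_flag.zero_()     # re-enable the work without re-capturing the graph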
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
...@@ -332,6 +356,7 @@ void cast_transpose(const Tensor &input,
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>), \
stream>>>( \
reinterpret_cast<const InputType *>(input.data.dptr), \
reinterpret_cast<const fp32 *>(noop.data.dptr), \
reinterpret_cast<OutputType *>(cast_output->data.dptr), \
reinterpret_cast<OutputType *>(transposed_output->data.dptr), \
reinterpret_cast<const fp32 *>(cast_output->scale.dptr), \
...@@ -417,7 +442,23 @@ void nvte_cast_transpose(const NVTETensor input,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose);
using namespace transformer_engine;
auto noop = Tensor();
cast_transpose(*reinterpret_cast<const Tensor*>(input),
noop,
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
stream);
}
void nvte_cast_transpose_with_noop(const NVTETensor input,
const NVTETensor noop,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_with_noop);
using namespace transformer_engine;
cast_transpose(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(noop),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
stream);
...
...@@ -22,9 +22,12 @@ constexpr size_t block_size = __BLOCK_SIZE__;
__global__ void
__launch_bounds__(block_size)
transpose_optimized_kernel(const Type * __restrict__ const input,
const float * const noop,
Type * __restrict__ const output,
const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
constexpr size_t nvec_in = load_size / sizeof(Type);
constexpr size_t nvec_out = store_size / sizeof(Type);
...
...@@ -4,6 +4,7 @@
 * See LICENSE for license information.
 ************************************************************************/
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h>
#include <iostream>
...@@ -30,9 +31,12 @@ template <size_t load_size, size_t store_size, typename Type>
__global__ void
__launch_bounds__(block_size)
transpose_general_kernel(const Type * __restrict__ const input,
const fp32 * const noop,
Type * __restrict__ const output,
const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
constexpr size_t nvec_in = load_size / sizeof(Type);
constexpr size_t nvec_out = store_size / sizeof(Type);
...@@ -124,6 +128,7 @@ transpose_general_kernel(const Type * __restrict__ const input,
}
void transpose(const Tensor &input,
const Tensor &noop,
Tensor *output_,
cudaStream_t stream) {
Tensor &output = *output_;
...@@ -140,6 +145,23 @@ void transpose(const Tensor &input,
NVTE_CHECK(input.data.dtype == output.data.dtype,
"Input and output type must match.");
// Number of elements in tensor
auto numel = [] (const Tensor &tensor) -> size_t {
size_t acc = 1;
for (const auto& dim : tensor.data.shape) {
acc *= dim;
}
return acc;
};
if (noop.data.dptr != nullptr) {
NVTE_CHECK(numel(noop) == 1,
"Expected 1 element, ",
"but found ", numel(noop), ".");
NVTE_CHECK(noop.data.dtype == DType::kFloat32);
NVTE_CHECK(noop.data.dptr != nullptr);
}
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.data.dtype, Type,
constexpr const char *type_name = TypeInfo<Type>::name;
constexpr size_t type_size = sizeof(Type);
...@@ -239,6 +261,7 @@ void transpose(const Tensor &input,
rtc_manager.launch(kernel_label,
num_blocks(load_size, store_size), block_size, 0, stream,
static_cast<const Type *>(input.data.dptr),
static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type*>(output.data.dptr),
row_length, num_rows);
} else { // Statically-compiled general kernel
...@@ -250,6 +273,7 @@ void transpose(const Tensor &input,
* DIVUP(num_rows, col_tile_size));
transpose_general_kernel<load_size, store_size, Type><<<num_blocks, block_size, 0, stream>>>(
static_cast<const Type *>(input.data.dptr),
static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type *>(output.data.dptr),
row_length, num_rows);
}
...@@ -263,7 +287,22 @@ void nvte_transpose(const NVTETensor input,
cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine;
auto noop = Tensor();
transpose(*reinterpret_cast<const Tensor*>(input),
noop,
reinterpret_cast<Tensor*>(output),
stream);
}
void nvte_transpose_with_noop(const NVTETensor input,
const NVTETensor noop,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose_with_noop);
using namespace transformer_engine;
transpose(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(noop),
reinterpret_cast<Tensor*>(output),
stream);
}
...@@ -14,6 +14,7 @@ from .attention import MultiheadAttention
from .transformer import TransformerLayer
from .fp8 import fp8_autocast
from .fp8 import fp8_model_init
from .graph import make_graphed_callables
from .export import onnx_export
from .distributed import checkpoint
from .distributed import CudaRNGStatesTracker
...
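The new export above is the user-facing entry point for graphed execution. A minimal usage sketch (not from this commit; the positional arguments are assumed to mirror torch.cuda.make_graphed_callables, and the exact keyword set may differ — see the API docs added in this PR):

import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024)
sample_args = (torch.randn(32, 1024, device="cuda", requires_grad=True),)

# Warms up, captures forward/backward graphs, and returns a graphed callable.
graphed_layer = te.make_graphed_callables(layer, sample_args)

out = graphed_layer(torch.randn(32, 1024, device="cuda", requires_grad=True))
out.sum().backward()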
...@@ -52,9 +52,14 @@ from transformer_engine.pytorch.distributed import (
get_distributed_world_size,
get_distributed_rank,
checkpoint,
set_all_rng_states,
CudaRNGStatesTracker,
graph_safe_rng_available,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("2.0.6")
...@@ -2401,10 +2406,13 @@ class DotProductAttention(torch.nn.Module):
assert (num_attention_heads % self.num_gqa_groups == 0
), "The number of attention heads must be divisible by the number of GQA groups!"
self.rng_states_tracker = None
if sequence_parallel or get_rng_state_tracker is None:
attention_dropout_ctx = nullcontext
else:
attention_dropout_ctx = get_rng_state_tracker().fork
self.rng_states_tracker = get_rng_state_tracker()
set_all_rng_states(self.rng_states_tracker.get_states())
attention_dropout_ctx = self.rng_states_tracker.fork
norm_factor = math.sqrt(self.hidden_size_per_attention_head)
...@@ -2648,6 +2656,14 @@ class DotProductAttention(torch.nn.Module):
assert (attn_mask_type in AttnMaskTypes
), f"Attention mask type {attn_mask_type} is not supported!"
if self.rng_states_tracker is not None and is_graph_capturing():
assert (
isinstance(self.rng_states_tracker, CudaRNGStatesTracker)
), "Unsupported RNG states tracker."
assert (
graph_safe_rng_available()
), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
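For context, a minimal sketch of wiring in the tracker that these assertions check for (not from this commit; constructor arguments are simplified and the module defaults may differ):

import torch
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker, graph_safe_rng_available
from transformer_engine.pytorch.attention import DotProductAttention

assert graph_safe_rng_available(), "graph capture of dropout needs a recent PyTorch"

tracker = CudaRNGStatesTracker()
tracker.add("attention-dropout", seed=1234)  # named RNG state forked around dropout

attn = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attention_dropout=0.1,
    get_rng_state_tracker=lambda: tracker,
)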
if window_size is None:
window_size = self.window_size
...@@ -3695,7 +3711,8 @@ class MultiheadAttention(torch.nn.Module):
# ===================
projection_output = self.proj(
context_layer, is_first_microbatch=is_first_microbatch
context_layer,
is_first_microbatch=is_first_microbatch,
)
if self.return_bias:
...
...@@ -22,19 +22,26 @@ def fp8_cast_transpose_fused(
otype: tex.DType,
cast_out: Optional[torch.Tensor] = None,
transpose_out: Optional[torch.Tensor] = None,
noop_flag: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], None]:
"""Cast + Transpose with FP8 output"""
return_outputs = False
if cast_out is None or transpose_out is None:
if transpose_out is None:
cast_out = torch.empty_like(inp, dtype=torch.uint8)
transpose_out = torch.empty(
inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8
)
return_outputs = True
if cast_out is None:
cast_out = torch.empty_like(inp, dtype=torch.uint8)
return_outputs = True
if noop_flag is None:
noop_flag = torch.Tensor()
tex.fused_cast_transpose(
tex.fused_cast_transpose_noop(
inp,
noop_flag,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
...
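To close the loop, a hedged usage sketch of the wrapper above with pre-allocated outputs and a noop flag, so the call can live inside a captured graph. The FP8 meta setup here is illustrative (real code takes fp8_meta_tensor from a module's FP8 state), and the compiled extension module name is an assumption:

import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import fp8_cast_transpose_fused

# Illustrative FP8 meta state with a single tensor slot.
meta = tex.FP8TensorMeta()
meta.scale = torch.ones(1, device="cuda")
meta.scale_inv = torch.ones(1, device="cuda")
meta.amax_history = torch.zeros(16, 1, device="cuda")

inp = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
cast_out = torch.empty_like(inp, dtype=torch.uint8)
transpose_out = torch.empty(256, 128, dtype=torch.uint8, device="cuda")
noop_flag = torch.zeros(1, dtype=torch.float32, device="cuda")

fp8_cast_transpose_fused(
    inp, meta, tex.FP8FwdTensors.GEMM1_INPUT, tex.DType.kFloat8E4M3,
    cast_out=cast_out, transpose_out=transpose_out, noop_flag=noop_flag,
)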