Commit 1edc9e13 authored by wenjh

Merge branch 'develop_v2.8' into release_v2.8

parents 5e7dd67e 3a040217
@@ -49,6 +49,7 @@ NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_tes
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
+mkdir -p $TE_PATH/artifacts/tests/pytorch/test_checkpoint && python $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all --checkpoint-dir $TE_PATH/artifacts/tests/pytorch/test_checkpoint/
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
...
@@ -340,22 +340,23 @@ def test_dpa_softmax(dtype, model_configs, model):
model_configs_mla = {
+    # TODO: FlashAttention on ROCm only supports MLA with head_dim_qk = head_dim_v
    # test: ModelConfig(b, sq, hq, dqk)
    "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),  # self , 0
    "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
    "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),  # self , 1
    "mla_2_1": ModelConfig(
        1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
    ),  # cross, 1
    "mla_2_2": ModelConfig(
        1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
    ),  # cross, 1
    "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
    "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),  # inference
    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
}
...
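If one wanted to act on the TODO above rather than only note it, a minimal sketch (not part of this commit; it assumes ModelConfig exposes head_dim_qk and head_dim_v attributes and that IS_HIP_EXTENSION is imported as in the other test files) could drop the unsupported MLA configs on ROCm:

if IS_HIP_EXTENSION:
    # Keep only MLA configs FlashAttention on ROCm can run (head_dim_qk == head_dim_v).
    model_configs_mla = {
        name: cfg
        for name, cfg in model_configs_mla.items()
        if cfg.head_dim_qk == cfg.head_dim_v
    }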
@@ -17,6 +17,7 @@ import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
from test_numerics import (
    _emulate_linear,
@@ -47,7 +48,6 @@ TEST_NR = 0
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()

def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
    if tp_size is None:
        tp_size = WORLD_SIZE
@@ -72,13 +72,23 @@ def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None,
def _init_model(weight, parallel_mode=None, tp_group=None, name="linear"):
-    model = transformer_engine.pytorch.Linear(
-        IN_SIZE,
-        OUT_SIZE,
-        name=name,
-        parallel_mode=parallel_mode,
-        tp_group=(tp_group or NCCL_WORLD if parallel_mode else None),
-    )
+    if IS_HIP_EXTENSION:
+        model = transformer_engine.pytorch.Linear(
+            IN_SIZE,
+            OUT_SIZE,
+            name=name,
+            bias=False,
+            parallel_mode=parallel_mode,
+            tp_group=(tp_group or NCCL_WORLD if parallel_mode else None),
+        )
+    else:
+        model = transformer_engine.pytorch.Linear(
+            IN_SIZE,
+            OUT_SIZE,
+            name=name,
+            parallel_mode=parallel_mode,
+            tp_group=(tp_group or NCCL_WORLD if parallel_mode else None),
+        )
    with torch.no_grad():
        model.weight.copy_(weight)
    return model
@@ -363,7 +373,6 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
    )
    set_weight_tensor_tp_group_reduce(True)  # reset

@run_debug_test
def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
    from test_log import LOG_QUANTIZED_CONFIG
@@ -580,6 +589,9 @@ def test_fake_quant_fp8(
        "dgrad_fp8": not (dgrad_weight or dgrad_grad),
        "wgrad_fp8": not (wgrad_grad or wgrad_input),
    }
+    if IS_HIP_EXTENSION:
+        if fp8_kwargs["fprop_fp8"] or fp8_kwargs["dgrad_fp8"] or fp8_kwargs["wgrad_fp8"]:
+            return  # Output type 32 (FP32) does not support int8 simulation.
    if WORLD_RANK == 0:
        fake_quant_fp8_create_config(
            fprop_inp,
@@ -667,30 +679,42 @@ if __name__ == "__main__":
    random.seed(SEED)
    _init_distributed()
+    if IS_HIP_EXTENSION:
+        # Output type 32 (FP32) does not support int8 simulation.
+        pass
+    else:
        test_log_expert_parallel()
        for parallel_mode in ["column", "row"]:
            for gather_weight in [True, False]:
                test_log_distributed(parallel_mode, gather_weight)
    if fp8_available:
        for parallel_mode in ["row", "column"]:
            test_disable_fp8_layer(parallel_mode)
+        if IS_HIP_EXTENSION:
+            # Output type 32 (FP32) does not support int8 simulation.
+            pass
+        else:
            # test_disable_fp8_gemms
            _run_test_with_combinations(
                test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"]
            )

            # test_fake_quant_fp8
            dtype_options = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None]
            _run_test_with_combinations(
                test_fake_quant_fp8,
                dtype_options,
                num_repeat=6,
                extra_args=["column", "row"],
                sample_size=20,
            )
+        if IS_HIP_EXTENSION:
+            # Output type 32 (FP32) does not support int8 simulation.
+            pass
+        else:
            _run_test_with_combinations(
                test_per_tensor_scaling,
                all_boolean,
...
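A hypothetical, more compact form of the _init_model change above (not part of the commit; it reuses the module-level names IN_SIZE, OUT_SIZE, NCCL_WORLD and IS_HIP_EXTENSION from the test file) would pass bias=False only on HIP builds instead of duplicating the constructor call:

def _init_model(weight, parallel_mode=None, tp_group=None, name="linear"):
    # Only HIP builds disable bias; everything else is shared.
    extra_kwargs = {"bias": False} if IS_HIP_EXTENSION else {}
    model = transformer_engine.pytorch.Linear(
        IN_SIZE,
        OUT_SIZE,
        name=name,
        parallel_mode=parallel_mode,
        tp_group=(tp_group or NCCL_WORLD if parallel_mode else None),
        **extra_kwargs,
    )
    with torch.no_grad():
        model.weight.copy_(weight)
    return model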
@@ -733,7 +733,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
def test_linear():
    """Run linear layer tests with various configurations."""
-    kwargs_list = [
+    base_kwargs_list = [
        {},
        {"bias": False},
        {"init_method": _constant},
@@ -743,7 +743,15 @@ def test_linear():
        {"delay_wgrad_compute": True},
        {"save_original_input": True},
    ]
+    # TODO: The blockwise recipe does not currently support bias=True.
+    # On AMD platforms with the fp8_block_scaling recipe, keep only the
+    # configurations that explicitly set bias=False.
+    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
+        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
+    else:
+        kwargs_list = base_kwargs_list
    for kwargs in kwargs_list:
        if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
            continue
@@ -913,7 +921,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
def test_layernorm_linear():
-    kwargs_list = [
+    base_kwargs_list = [
        {},
        {"bias": False},
        {"init_method": _constant},
@@ -924,7 +932,15 @@ def test_layernorm_linear():
        {"return_layernorm_output": True},
        {"delay_wgrad_compute": True},
    ]
+    # TODO: The blockwise recipe does not currently support bias=True.
+    # On AMD platforms with the fp8_block_scaling recipe, keep only the
+    # configurations that explicitly set bias=False.
+    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
+        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
+    else:
+        kwargs_list = base_kwargs_list
    for kwargs in kwargs_list:
        for parallel_mode in ["column"]:
            for sequence_parallel in [False, True]:
@@ -1019,7 +1035,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
def test_layernorm_mlp():
-    kwargs_list = [
+    base_kwargs_list = [
        {},
        {"init_method": _constant},
        {"output_layer_init_method": _constant},
@@ -1033,7 +1049,15 @@ def test_layernorm_mlp():
        {"return_layernorm_output": True},
        {"delay_wgrad_compute": True},
    ]
+    # TODO: The blockwise recipe does not currently support bias=True.
+    # On AMD platforms with the fp8_block_scaling recipe, keep only the
+    # configurations that explicitly set bias=False.
+    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
+        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
+    else:
+        kwargs_list = base_kwargs_list
    for kwargs in kwargs_list:
        for set_parallel_mode in [True]:
            for sequence_parallel in [False, True]:
@@ -1108,7 +1132,7 @@ def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs):
def test_transformer_layer():
-    kwargs_list = [
+    base_kwargs_list = [
        {},
        {"num_gqa_groups": 4},
        {"init_method": _constant},
@@ -1128,6 +1152,15 @@ def test_transformer_layer():
        {"fuse_qkv_params": True},
        {"activation": "relu"},
    ]
+    # TODO: The blockwise recipe does not currently support bias=True.
+    # On AMD platforms with the fp8_block_scaling recipe, keep only the
+    # configurations that explicitly set bias=False.
+    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
+        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
+    else:
+        kwargs_list = base_kwargs_list
    for kwargs in kwargs_list:
        for sequence_parallel in [False, True]:
...
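The bias filter above is repeated verbatim in test_linear, test_layernorm_linear, test_layernorm_mlp and test_transformer_layer. A minimal sketch of a shared helper (hypothetical; the name _filter_bias_kwargs is not in the commit) that expresses the same rule once:

def _filter_bias_kwargs(base_kwargs_list, is_hip, quantization):
    # On ROCm with the fp8_block_scaling recipe, keep only configurations that
    # explicitly set bias=False; otherwise return the list unchanged.
    if is_hip and quantization == "fp8_block_scaling":
        return [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
    return base_kwargs_list

Each test would then assign kwargs_list = _filter_bias_kwargs(base_kwargs_list, IS_HIP_EXTENSION, QUANTIZATION).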
@@ -9,7 +9,8 @@ from pathlib import Path
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
+import transformer_engine as te

"""
Distributed numerics tests
@@ -66,4 +67,15 @@ def test_distributed(quantization):
        pytest.skip(reason_for_no_fp8_block_scaling)
    if quantization == "nvfp4" and not nvfp4_available:
        pytest.skip(reason_for_no_nvfp4)
+    if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
+        import importlib
+        ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8")
+        os.environ["NVTE_INT8_SIM_FP8"] = "1"
+        importlib.reload(te.pytorch.fp8)
    _run_test(quantization)
+    if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
+        if ori_int8_sim_fp8 is None:
+            del os.environ["NVTE_INT8_SIM_FP8"]
+        else:
+            os.environ["NVTE_INT8_SIM_FP8"] = ori_int8_sim_fp8
+        importlib.reload(te.pytorch.fp8)
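The save/override/restore dance around NVTE_INT8_SIM_FP8 could also be packaged as a context manager. A minimal sketch, assuming the same environment variable and te.pytorch.fp8 reload used by the test above (the helper name is invented):

import contextlib
import importlib
import os

import transformer_engine as te
import transformer_engine.pytorch  # make sure te.pytorch.fp8 is importable


@contextlib.contextmanager
def int8_sim_fp8():
    # Temporarily force NVTE_INT8_SIM_FP8=1 and reload te.pytorch.fp8 so the
    # module re-reads the flag; restore (or unset) the variable afterward.
    original = os.environ.get("NVTE_INT8_SIM_FP8")
    os.environ["NVTE_INT8_SIM_FP8"] = "1"
    importlib.reload(te.pytorch.fp8)
    try:
        yield
    finally:
        if original is None:
            del os.environ["NVTE_INT8_SIM_FP8"]
        else:
            os.environ["NVTE_INT8_SIM_FP8"] = original
        importlib.reload(te.pytorch.fp8)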
@@ -28,9 +28,9 @@ if IS_HIP_EXTENSION:
from functools import cache

# Check if FP8 is supported.
-fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
-mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Reset RNG states.
reset_rng_states()
@@ -367,6 +367,12 @@ def test_make_graphed_callables(
    )
    if fp8_params:
        pytest.skip("NVFP4 params not supported")
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
    # Run model with different CUDA graph settings.
    model_config = model_configs[model_config]
...
@@ -6,7 +6,6 @@ import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
-from transformer_engine.pytorch.utils import use_lightop_w8a8
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
...
@@ -2111,7 +2111,8 @@ class TestFusedOps:
        quantized_weight: bool = False,
    ) -> None:
        """Forward GEMM + scale + add"""
+        if IS_HIP_EXTENSION and scale != 1:
+            pytest.skip("alpha must be 1.0 for hip")
        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
@@ -2496,7 +2497,8 @@ class TestFusedOps:
        quantized_weight: bool = False,
    ) -> None:
        """Backward dgrad GEMM + scale"""
+        if IS_HIP_EXTENSION and scale != 1:
+            pytest.skip("alpha must be 1.0 for hip")
        # Make input and weight shapes consistent
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
...
@@ -56,7 +56,7 @@ from utils import ModelConfig, reset_rng_states, get_available_attention_backend
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
sm_80plus = get_device_compute_capability() >= (8, 0)
@@ -606,6 +606,13 @@ def _test_e2e_selective_recompute(
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
    config = model_configs[model]
@@ -714,8 +721,15 @@ def _test_e2e_full_recompute(
def test_gpt_full_activation_recompute(
    dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
    config = model_configs[model]
@@ -1301,9 +1315,14 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
    fuse_wgrad_accumulation = True
    fp8_model_params = False
    fp8 = recipe is not None
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
@@ -1818,6 +1837,12 @@ def test_grouped_linear_accuracy(
    use_cutlass=False,
):
    fp8 = recipe is not None
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
@@ -1863,7 +1888,8 @@ def test_grouped_linear_accuracy(
        weight_i = getattr(grouped_linear, f"weight{i}")
        weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
        sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
    outputs_ref = _test_grouped_linear_accuracy(
        sequential_linear,
        num_gemms,
@@ -1886,7 +1912,8 @@ def test_grouped_linear_accuracy(
        fuse_wgrad_accumulation,
        delay_wgrad_compute,
    )
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "0"
    for o, o_ref in zip(outputs, outputs_ref):
        if use_cutlass:
            torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
@@ -1956,6 +1983,12 @@ def test_grouped_linear_accuracy_save_original_input(
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2162,8 +2195,14 @@ def test_padding_grouped_linear_accuracy(
    fp8_model_params,
    parallel_mode=None,
):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2235,6 +2274,12 @@ def test_padding_grouped_linear_accuracy_save_original_input(
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2446,8 +2491,14 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
    if NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
    config = model_configs[model]
...
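The recipe-availability skips added throughout this file follow one pattern; a minimal consolidation sketch (hypothetical, not part of the commit) using the same FP8GlobalStateManager queries could read:

import pytest
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)


def maybe_skip_for_recipe(fp8, recipe):
    # Skip a test when the requested FP8 recipe cannot run on this device.
    if not fp8:
        return
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)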
@@ -33,7 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_torch as tex
-from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
+from transformer_engine.pytorch.onnx_extensions import te_translation_table
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
+from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import get_default_init_method
import tensorrt as trt
@@ -65,7 +67,6 @@ if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_available:
    fp8_recipes.append(recipe.DelayedScaling())
-    fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(None)

supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
@@ -82,11 +83,11 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
    ],
    outputs=[PyCustomOpDef.dt_uint8],
)
-def trt_fp8_quantize(t, scale_inv):
+def trt_fp8_quantize(t, scale):
    """FP8 quantization extension for ONNX Runtime."""
    x = torch.from_numpy(t).cuda()
    q = te.tensor.float8_tensor.Float8Quantizer(
-        scale=1 / torch.from_numpy(scale_inv).cuda(),
+        scale=1 / torch.from_numpy(scale).cuda(),
        amax=torch.zeros([1]).cuda(),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )
@@ -102,11 +103,11 @@ def trt_fp8_quantize(t, scale_inv):
    ],
    outputs=[PyCustomOpDef.dt_float],
)
-def trt_fp8_dequantize(t, scale_inv):
+def trt_fp8_dequantize(t, scale):
    """FP8 dequantization extension for ONNX Runtime."""
    x = torch.from_numpy(t).cuda()
    q = te.tensor.float8_tensor.Float8Quantizer(
-        scale=1 / torch.from_numpy(scale_inv).cuda(),
+        scale=1 / torch.from_numpy(scale).cuda(),
        amax=torch.zeros([1]).cuda(),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )
@@ -469,16 +470,22 @@ def _test_export_linear(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_linear(fp8_recipe=fp8_recipe, precision=precision)

@pytest.mark.parametrize("use_bias", [True, False])
def test_export_linear_use_bias(seed_default_rng, use_bias):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_linear(use_bias=use_bias)

@pytest.mark.parametrize("return_bias", [True, False])
def test_export_linear_return_bias(seed_default_rng, return_bias):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_linear(return_bias=return_bias)
@@ -540,6 +547,8 @@ def test_export_layernorm_zero_centered_gamma(seed_default_rng):
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_normalization(seed_default_rng, normalization):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm(normalization=normalization)
@@ -594,9 +603,7 @@ def _test_export_layernorm_linear(
        fname,
        inp,
        model,
-        # For current scaling we use Float8Quantizer in tests + amax computed by hand,
-        # which has slightly different numerics than Float8CurrentScalingQuantizer.
-        atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2,
+        atol=1e-3,
        is_fp8=fp8_recipe is not None,
        te_outputs=te_outputs,
    )
@@ -605,27 +612,39 @@ def _test_export_layernorm_linear(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision)

def test_export_layernorm_linear_return_ln_out(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_linear(return_layernorm_output=True)

def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_linear(zero_centered_gamma=True)

@pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_linear_normalization(seed_default_rng, normalization):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_linear(normalization=normalization)

def test_export_layernorm_linear_no_bias(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_linear(use_bias=False)

def test_export_layernorm_linear_return_bias(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_linear(return_bias=True)
@@ -684,32 +703,46 @@ def _test_export_layernorm_mlp(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision)

def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(return_layernorm_output=True)

def test_export_layernorm_mlp_return_bias(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(return_bias=True)

def test_export_layernorm_mlp_no_bias(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(use_bias=False)

def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(zero_centered_gamma=True)

@pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_mlp_normalization(seed_default_rng, normalization):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(normalization=normalization)

@pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_layernorm_mlp_activation(seed_default_rng, activation):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_layernorm_mlp(activation=activation)
@@ -731,6 +764,8 @@ def test_export_core_attention(
    use_mask: bool,
    attn_mask_type: str,
):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    # Set dimensions (these are arbitrary).
    seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
    qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels)
@@ -932,22 +967,32 @@ def _test_export_multihead_attention(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_multihead_attention_recipe(fp8_recipe, precision):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision)

def test_export_multihead_attention_no_mask():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_multihead_attention(use_mask=False)

def test_export_multihead_attention_no_input_layernorm():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_multihead_attention(input_layernorm=False)

def test_export_multihead_attention_cross_attn():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_multihead_attention(attention_type="cross")

def test_export_multihead_attention_unfused_qkv_params():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_multihead_attention(fuse_qkv_params=False)
@@ -1023,27 +1068,39 @@ def _test_export_transformer_layer(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_transformer_layer_recipe(fp8_recipe, precision):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision)

def test_export_transformer_layer_no_mask():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_transformer_layer(use_mask=False)

def test_export_transformer_layer_output_layernorm():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_transformer_layer(output_layernorm=True)

def test_export_transformer_layer_unfused_qkv_params():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_transformer_layer(fuse_qkv_params=False)

def test_export_transformer_layer_zero_centered_gamma():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_transformer_layer(zero_centered_gamma=True)

@pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_transformer_layer_activation(activation):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    _test_export_transformer_layer(activation=activation)
@@ -1056,7 +1113,8 @@ def test_export_gpt_generation(
    """Test that the ONNX model can correctly handle inputs with different shapes and that
    the attention mask is adjusted on-the-fly to different sequence lengths.
    """
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
    # Layer configuration
    hidden_size = 64
    sequence_length = 128
@@ -1147,17 +1205,14 @@ def test_export_ctx_manager(enabled):
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_trt_integration(fp8_recipe: recipe.Recipe):
+    if IS_HIP_EXTENSION:
+        pytest.skip("TRT is not supported for HIP")
    model = te.TransformerLayer(
        hidden_size=128,
        ffn_hidden_size=128,
        num_attention_heads=4,
    ).eval()
-    if type(fp8_recipe) == recipe.Float8CurrentScaling:
-        # TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
-        model = te.LayerNormMLP(128, 128)
    inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)

    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
...
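Since every ONNX export test in this file now begins with the same HIP check, an alternative (hypothetical, not part of the commit) is a single module-level marker:

import pytest
from torch.utils.cpp_extension import IS_HIP_EXTENSION

# Skip the whole ONNX export suite on HIP builds instead of repeating the
# check inside each test.
pytestmark = pytest.mark.skipif(
    IS_HIP_EXTENSION, reason="ONNX is not currently required in hip"
)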
@@ -46,7 +46,7 @@ from utils import ModelConfig
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

# Record initial RNG state from script run.
@@ -388,7 +388,13 @@ def test_sanity_layernorm_linear(
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
@@ -450,7 +456,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
    num_tokens = bs * config.max_seqlen_q
    if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
@@ -543,7 +555,13 @@ def test_sanity_layernorm_mlp(
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
@@ -587,7 +605,13 @@ def test_sanity_gpt(
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
@@ -759,7 +783,13 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
@@ -791,7 +821,13 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
@@ -827,7 +863,13 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
@@ -863,7 +905,13 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
    config = model_configs[model]
    if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
        if fp8_recipe.nvfp4() and dtype == torch.float16:
            pytest.skip("FP16 output for NVFP4 not supported")
...
@@ -1561,15 +1561,6 @@ void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor
    NVTE_ERROR("TT layout not allowed.");
  }

-  hipblasLtHandle_t handle = nullptr;
-  // Init hipblaslt handles (once, globally)
-  static std::once_flag init_flag;
-  static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
-  std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
-  handle = hipblaslt_handles[0];
  NVTE_ERROR("Remove nvte_cublas_batchgemm_tensorwise_int8 for now.");
}
...
/*************************************************************************
 * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
 *
 * License for AMD contributions = MIT. See LICENSE for more information
 ************************************************************************/

#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

#include <type_traits>

#ifdef USE_HIPBLASLT
#include <hipblaslt/hipblaslt.h>
#include <unistd.h>

#include <chrono>
#include <forward_list>
#include <fstream>
#include <mutex>
#include <optional>
#include <sstream>
#include <unordered_map>
#include <vector>
#endif

#ifdef USE_ROCBLAS
#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>

#include <hipblaslt/hipblaslt-ext.hpp>
#include <hipcub/hipcub.hpp>
#endif

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>

#include "../common.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "../util/vectorized_pointwise.h"

namespace {

#ifdef USE_HIPBLASLT
static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    case DType::kFloat16:
      return HIP_R_16F;
    case DType::kFloat32:
      return HIP_R_32F;
    case DType::kBFloat16:
      return HIP_R_16BF;
    case DType::kFloat8E4M3:
      return HIP_R_8F_E4M3;
    case DType::kFloat8E5M2:
      return HIP_R_8F_E5M2;
    case DType::kInt8:
      return HIP_R_8I;
    case DType::kInt32:
      return HIP_R_32I;
    default:
      NVTE_ERROR("Invalid type");
  }
}
#endif

#ifdef USE_ROCBLAS
rocblas_datatype get_rocblas_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    case DType::kFloat16:
      return rocblas_datatype_f16_r;
    case DType::kFloat32:
      return rocblas_datatype_f32_r;
    case DType::kBFloat16:
      return rocblas_datatype_bf16_r;
    case DType::kFloat8E4M3:
      return rocblas_datatype_f8_r;
    case DType::kFloat8E5M2:
      return rocblas_datatype_bf8_r;
    default:
      NVTE_ERROR("Invalid type");
  }
}
#endif
}  //namespace

namespace transformer_engine {

#ifdef USE_ROCBLAS
namespace detail {

struct Empty {};

__device__ inline fp32 identity(fp32 value, const Empty&) { return value; }

__inline__ __device__ float gelu(float x, const Empty&) {
  float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
  return x * cdf;
}

__inline__ __device__ float gelu_forward(float x) {
  float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
  return x * cdf;
}

template <typename T, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    gelu_forward_kernel(const float* in, T* out, float* amax, const float* scale, int m, int n) {
// fp8 output flow // fp8 output flow
if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) { if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce; typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_temp_storage; __shared__ typename BlockReduce::TempStorage block_temp_storage;
float thread_amax = 0; float thread_amax = 0;
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
float x = in[id]; float x = in[id];
float y = gelu_forward(x); float y = gelu_forward(x);
out[id] = (T)((*scale) * y); out[id] = (T)((*scale) * y);
thread_amax = std::fmax(std::fabs(y), thread_amax); thread_amax = std::fmax(std::fabs(y), thread_amax);
} }
float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max()); float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
atomicMaxFloat(amax, block_amax); atomicMaxFloat(amax, block_amax);
} }
} else { } else {
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
float x = in[id]; float x = in[id];
float y = gelu_forward(x); float y = gelu_forward(x);
out[id] = (T)(y); out[id] = (T)(y);
} }
} }
} }
template <typename T> template <typename T>
void gelu_forward_kernelLauncher(const float* in, T* out, float* amax, const float* scale, int m, void gelu_forward_kernelLauncher(const float* in, T* out, float* amax, const float* scale, int m,
int n, hipStream_t stream) { int n, hipStream_t stream) {
dim3 block, grid; dim3 block, grid;
constexpr int THREADS_PER_BLOCK = 1024; constexpr int THREADS_PER_BLOCK = 1024;
block.x = THREADS_PER_BLOCK; block.x = THREADS_PER_BLOCK;
grid.x = ceil(1.0 * m * n / THREADS_PER_BLOCK); grid.x = ceil(1.0 * m * n / THREADS_PER_BLOCK);
hipLaunchKernelGGL((gelu_forward_kernel<T, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, hipLaunchKernelGGL((gelu_forward_kernel<T, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0,
stream, in, out, amax, scale, m, n); stream, in, out, amax, scale, m, n);
} }
__inline__ __device__ float gelu_backward(float x, float dy) { __inline__ __device__ float gelu_backward(float x, float dy) {
constexpr float kBeta = 0.7978845608028654f; constexpr float kBeta = 0.7978845608028654f;
constexpr float kKappa = 0.044715f; constexpr float kKappa = 0.044715f;
float x_sq = x * x; float x_sq = x * x;
float x_cube = x_sq * x; float x_cube = x_sq * x;
float tanh_inner = tanhf((kBeta * (x + kKappa * x_cube))); float tanh_inner = tanhf((kBeta * (x + kKappa * x_cube)));
float left = 0.5 * x; float left = 0.5 * x;
float right = 1.0f + tanh_inner; float right = 1.0f + tanh_inner;
float left_derivative = 0.5 * right; float left_derivative = 0.5 * right;
float tanh_derivative = 1 - tanh_inner * tanh_inner; float tanh_derivative = 1 - tanh_inner * tanh_inner;
float inner_derivative = kBeta * (1.0f + 3.0 * kKappa * x_sq); float inner_derivative = kBeta * (1.0f + 3.0 * kKappa * x_sq);
float right_derivative = left * tanh_derivative * inner_derivative; float right_derivative = left * tanh_derivative * inner_derivative;
return dy * (left_derivative + right_derivative); return dy * (left_derivative + right_derivative);
} }
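// Illustrative standalone sketch (not from this diff): it cross-checks the tanh-approximation GELU
// above against its analytic gradient (the same algebra as gelu_backward with dy = 1) using a
// central finite difference on the host. The names ref_gelu/ref_gelu_grad are hypothetical.
#include <cmath>
#include <cstdio>

static float ref_gelu(float x) {
  // same constants as gelu_forward(): 0.5 * x * (1 + tanh(0.7978845608 * (x + 0.044715 * x^3)))
  return 0.5f * x * (1.0f + tanhf(0.7978845608028654f * (x + 0.044715f * x * x * x)));
}

static float ref_gelu_grad(float x) {
  // same algebra as gelu_backward() with dy == 1
  const float kBeta = 0.7978845608028654f, kKappa = 0.044715f;
  const float t = tanhf(kBeta * (x + kKappa * x * x * x));
  return 0.5f * (1.0f + t) + 0.5f * x * (1.0f - t * t) * kBeta * (1.0f + 3.0f * kKappa * x * x);
}

int main() {
  const float xs[] = {-2.0f, -0.5f, 0.0f, 1.0f, 3.0f};
  const float h = 1e-3f;
  for (float x : xs) {
    const float numeric = (ref_gelu(x + h) - ref_gelu(x - h)) / (2.0f * h);
    std::printf("x=% .2f analytic=% .6f numeric=% .6f\n", x, ref_gelu_grad(x), numeric);
  }
  return 0;
}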
template <typename T, typename Taux> template <typename T, typename Taux>
__global__ void gelu_backward_kernel(const float* dy, T* out, const Taux* __restrict pre_gelu_out, __global__ void gelu_backward_kernel(const float* dy, T* out, const Taux* __restrict pre_gelu_out,
int m, int n) { int m, int n) {
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
float x = (float)pre_gelu_out[id]; float x = (float)pre_gelu_out[id];
float dx = (float)gelu_backward(x, dy[id]); float dx = (float)gelu_backward(x, dy[id]);
out[id] = (T)(dx); out[id] = (T)(dx);
} }
} }
template <typename T, typename Taux> template <typename T, typename Taux>
void gelu_backward_kernelLauncher(const float* in, T* out, const Taux* pre_gelu_out, int m, int n, void gelu_backward_kernelLauncher(const float* in, T* out, const Taux* pre_gelu_out, int m, int n,
hipStream_t stream) { hipStream_t stream) {
int blocks_per_row = ceil(float(n) / 256); int blocks_per_row = ceil(float(n) / 256);
dim3 grid(min(m * blocks_per_row, 65536)); dim3 grid(min(m * blocks_per_row, 65536));
dim3 block(min(n, 256)); dim3 block(min(n, 256));
hipLaunchKernelGGL((gelu_backward_kernel<T, Taux>), dim3(grid), dim3(block), 0, stream, in, out, hipLaunchKernelGGL((gelu_backward_kernel<T, Taux>), dim3(grid), dim3(block), 0, stream, in, out,
pre_gelu_out, m, n); pre_gelu_out, m, n);
} }
template <typename T, typename Tb, int THREADS_PER_BLOCK> template <typename T, typename Tb, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK) __global__ void __launch_bounds__(THREADS_PER_BLOCK)
add_bias_kernel(const float* in, T* out, const Tb* __restrict bias, float* amax, add_bias_kernel(const float* in, T* out, const Tb* __restrict bias, float* amax,
const float* scale, int m, int n) { const float* scale, int m, int n) {
// fp8 output flow // fp8 output flow
if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) { if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce; typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_temp_storage; __shared__ typename BlockReduce::TempStorage block_temp_storage;
float thread_amax = 0; float thread_amax = 0;
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
float reg_bias = (float)bias[id % n]; float reg_bias = (float)bias[id % n];
float val = in[id] + reg_bias; float val = in[id] + reg_bias;
out[id] = (T)((*scale) * val); out[id] = (T)((*scale) * val);
// deal with amax of D // deal with amax of D
thread_amax = std::fmax(std::fabs(val), thread_amax); thread_amax = std::fmax(std::fabs(val), thread_amax);
} }
// num_valid can be ignored since each thread's amax is initialized to 0 // num_valid can be ignored since each thread's amax is initialized to 0
float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max()); float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
atomicMaxFloat(amax, block_amax); atomicMaxFloat(amax, block_amax);
} }
} else { } else {
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
float reg_bias = (float)bias[id % n]; float reg_bias = (float)bias[id % n];
float val = in[id] + reg_bias; float val = in[id] + reg_bias;
out[id] = (T)(val); out[id] = (T)(val);
} }
} }
} }
template <typename T, typename Tb> template <typename T, typename Tb>
void add_bias_kernelLauncher(const float* in, T* out, const Tb* __restrict bias, float* amax, void add_bias_kernelLauncher(const float* in, T* out, const Tb* __restrict bias, float* amax,
const float* scale, int m, int n, hipStream_t stream) { const float* scale, int m, int n, hipStream_t stream) {
dim3 block, grid; dim3 block, grid;
constexpr int THREADS_PER_BLOCK = 1024; constexpr int THREADS_PER_BLOCK = 1024;
block.x = THREADS_PER_BLOCK; block.x = THREADS_PER_BLOCK;
grid.x = ceil(1.0 * m * n / THREADS_PER_BLOCK); grid.x = ceil(1.0 * m * n / THREADS_PER_BLOCK);
hipLaunchKernelGGL((add_bias_kernel<T, Tb, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, hipLaunchKernelGGL((add_bias_kernel<T, Tb, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0,
stream, in, out, bias, amax, scale, m, n); stream, in, out, bias, amax, scale, m, n);
} }
template <typename T, typename Taux, typename Tb, int THREADS_PER_BLOCK> template <typename T, typename Taux, typename Tb, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK) __global__ void __launch_bounds__(THREADS_PER_BLOCK)
add_bias_gelu_kernel(const float* in, T* out, Taux* pre_gelu_out, const Tb* __restrict bias, add_bias_gelu_kernel(const float* in, T* out, Taux* pre_gelu_out, const Tb* __restrict bias,
float* amax, const float* scale, int m, int n) { float* amax, const float* scale, int m, int n) {
// fp8 output flow // fp8 output flow
if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) { if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
// only need to deal with amax and scale of D, no need to deal with amax and scale of pre_gelu_out // only need to deal with amax and scale of D, no need to deal with amax and scale of pre_gelu_out
typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce; typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_temp_storage; __shared__ typename BlockReduce::TempStorage block_temp_storage;
float thread_amax = 0; float thread_amax = 0;
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
float reg_bias = (float)bias[id % n]; float reg_bias = (float)bias[id % n];
float val = in[id] + reg_bias; float val = in[id] + reg_bias;
// pre_gelu_out guaranteed not to be fp8 type // pre_gelu_out guaranteed not to be fp8 type
pre_gelu_out[id] = (Taux)(val); pre_gelu_out[id] = (Taux)(val);
val = gelu_forward(val); val = gelu_forward(val);
out[id] = (T)((*scale) * val); out[id] = (T)((*scale) * val);
// deal with amax of D // deal with amax of D
thread_amax = std::fmax(std::fabs(val), thread_amax); thread_amax = std::fmax(std::fabs(val), thread_amax);
} }
// num_valid can be ignored since each thread's amax is initialized to 0 // num_valid can be ignored since each thread's amax is initialized to 0
float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max()); float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
atomicMaxFloat(amax, block_amax); atomicMaxFloat(amax, block_amax);
} }
} else { } else {
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
float reg_bias = (float)bias[id % n]; float reg_bias = (float)bias[id % n];
float val = in[id] + reg_bias; float val = in[id] + reg_bias;
pre_gelu_out[id] = (Taux)(val); pre_gelu_out[id] = (Taux)(val);
out[id] = (T)(gelu_forward(val)); out[id] = (T)(gelu_forward(val));
} }
} }
} }
template <typename T, typename Taux, typename Tb> template <typename T, typename Taux, typename Tb>
void add_bias_gelu_kernelLauncher(const float* in, T* out, Taux* pre_gelu_out, void add_bias_gelu_kernelLauncher(const float* in, T* out, Taux* pre_gelu_out,
const Tb* __restrict bias, float* amax, const float* scale, int m, const Tb* __restrict bias, float* amax, const float* scale, int m,
int n, hipStream_t stream) { int n, hipStream_t stream) {
dim3 block, grid; dim3 block, grid;
constexpr int THREADS_PER_BLOCK = 1024; constexpr int THREADS_PER_BLOCK = 1024;
block.x = THREADS_PER_BLOCK; block.x = THREADS_PER_BLOCK;
grid.x = ceil(1.0 * m * n / THREADS_PER_BLOCK); grid.x = ceil(1.0 * m * n / THREADS_PER_BLOCK);
hipLaunchKernelGGL((add_bias_gelu_kernel<T, Taux, Tb, THREADS_PER_BLOCK>), dim3(grid), hipLaunchKernelGGL((add_bias_gelu_kernel<T, Taux, Tb, THREADS_PER_BLOCK>), dim3(grid),
dim3(block), 0, stream, in, out, pre_gelu_out, bias, amax, scale, m, n); dim3(block), 0, stream, in, out, pre_gelu_out, bias, amax, scale, m, n);
} }
template <typename Tin, typename T> template <typename Tin, typename T>
__global__ void identity_kernel(const Tin* in, T* out, int n) { __global__ void identity_kernel(const Tin* in, T* out, int n) {
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x) { for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x) {
Tin val = in[id]; Tin val = in[id];
out[id] = (T)(val); out[id] = (T)(val);
} }
} }
template <typename Tin, typename T> template <typename Tin, typename T>
void identity_kernelLauncher(const Tin* in, T* out, int n, hipStream_t stream) { void identity_kernelLauncher(const Tin* in, T* out, int n, hipStream_t stream) {
dim3 block, grid; dim3 block, grid;
block.x = 256; block.x = 256;
grid.x = ceil(n / 256.); grid.x = ceil(n / 256.);
hipLaunchKernelGGL((identity_kernel<Tin, T>), dim3(grid), dim3(block), 0, stream, in, out, n); hipLaunchKernelGGL((identity_kernel<Tin, T>), dim3(grid), dim3(block), 0, stream, in, out, n);
} }
template <typename T, int THREADS_PER_BLOCK> template <typename T, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK) __global__ void __launch_bounds__(THREADS_PER_BLOCK)
identity_output_kernel(const float* in, T* out, float* amax, const float* scale, int n) { identity_output_kernel(const float* in, T* out, float* amax, const float* scale, int n) {
if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) { if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce; typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_temp_storage; __shared__ typename BlockReduce::TempStorage block_temp_storage;
float thread_amax = 0; float thread_amax = 0;
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x) { for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x) {
float val = in[id]; float val = in[id];
out[id] = (T)((*scale) * val); out[id] = (T)((*scale) * val);
// deal with amax of D // deal with amax of D
thread_amax = std::fmax(std::fabs(val), thread_amax); thread_amax = std::fmax(std::fabs(val), thread_amax);
} }
// num_valid can be ignored since each thread's amax is initialized to 0 // num_valid can be ignored since each thread's amax is initialized to 0
float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max()); float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
atomicMaxFloat(amax, block_amax); atomicMaxFloat(amax, block_amax);
} }
} else { } else {
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x) { for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x) {
float val = in[id]; float val = in[id];
out[id] = (T)(val); out[id] = (T)(val);
} }
} }
} }
template <typename T> template <typename T>
void identity_output_kernelLauncher(const float* in, T* out, float* amax, const float* scale, int n, void identity_output_kernelLauncher(const float* in, T* out, float* amax, const float* scale, int n,
hipStream_t stream) { hipStream_t stream) {
dim3 block, grid; dim3 block, grid;
constexpr int THREADS_PER_BLOCK = 1024; constexpr int THREADS_PER_BLOCK = 1024;
block.x = THREADS_PER_BLOCK; block.x = THREADS_PER_BLOCK;
grid.x = ceil(1.0 * n / THREADS_PER_BLOCK); grid.x = ceil(1.0 * n / THREADS_PER_BLOCK);
hipLaunchKernelGGL((identity_output_kernel<T, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, hipLaunchKernelGGL((identity_output_kernel<T, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0,
stream, in, out, amax, scale, n); stream, in, out, amax, scale, n);
} }
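// Illustrative standalone sketch (not from this diff): the recurring FP8 output pattern in the
// kernels above, i.e. out[i] = cast(scale * in[i]) while amax tracks max(|in[i]|) of the unscaled
// values. The names to_fp8_like/quantize_with_amax and the clamp range are hypothetical stand-ins.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static float to_fp8_like(float v) {
  // crude stand-in for an FP8 cast: clamp to the e4m3 representable range
  return std::max(-448.0f, std::min(448.0f, v));
}

static void quantize_with_amax(const std::vector<float>& in, float scale,
                               std::vector<float>& out, float& amax) {
  out.resize(in.size());
  amax = 0.0f;
  for (size_t i = 0; i < in.size(); ++i) {
    out[i] = to_fp8_like(scale * in[i]);       // scaled, narrowed output
    amax = std::fmax(amax, std::fabs(in[i]));  // amax of the unscaled value, as in the kernels above
  }
}

int main() {
  const std::vector<float> in = {0.25f, -3.5f, 7.0f, -0.125f};
  std::vector<float> out;
  float amax = 0.0f;
  quantize_with_amax(in, /*scale=*/8.0f, out, amax);
  std::printf("amax=%.3f out[2]=%.3f\n", amax, out[2]);  // amax=7.000 out[2]=56.000
  return 0;
}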
template <typename Tin, int THREADS_PER_BLOCK> template <typename Tin, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK) __global__ void __launch_bounds__(THREADS_PER_BLOCK)
bias_gradient_kernel(const Tin* in, float* out, int m, int n) { bias_gradient_kernel(const Tin* in, float* out, int m, int n) {
typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce; typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_temp_storage; __shared__ typename BlockReduce::TempStorage block_temp_storage;
int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK); int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK);
int THREADS_PER_COL = BLOCKS_PER_COL * THREADS_PER_BLOCK; int THREADS_PER_COL = BLOCKS_PER_COL * THREADS_PER_BLOCK;
int idx = threadIdx.x + blockIdx.x * blockDim.x; int idx = threadIdx.x + blockIdx.x * blockDim.x;
int col_idx = idx / THREADS_PER_COL; int col_idx = idx / THREADS_PER_COL;
int row_idx = idx % THREADS_PER_COL; int row_idx = idx % THREADS_PER_COL;
float thread_data; float thread_data;
if (row_idx < m) thread_data = (float)in[row_idx * n + col_idx]; if (row_idx < m) thread_data = (float)in[row_idx * n + col_idx];
float local_sum; float local_sum;
if (row_idx < (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK) { if (row_idx < (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK) {
local_sum = BlockReduce(block_temp_storage).Sum(thread_data); local_sum = BlockReduce(block_temp_storage).Sum(thread_data);
} else { } else {
local_sum = BlockReduce(block_temp_storage) local_sum = BlockReduce(block_temp_storage)
.Sum(thread_data, m - (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK); .Sum(thread_data, m - (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK);
} }
if (threadIdx.x == 0) atomicAdd(&out[col_idx], local_sum); if (threadIdx.x == 0) atomicAdd(&out[col_idx], local_sum);
} }
constexpr int kColwiseReduceTileSize = 32; constexpr int kColwiseReduceTileSize = 32;
template <typename T> template <typename T>
__inline__ __device__ T WarpReduceSum(T val, int max = 32) { __inline__ __device__ T WarpReduceSum(T val, int max = 32) {
for (int offset = max; offset > 0; offset >>= 1) { for (int offset = max; offset > 0; offset >>= 1) {
val += __shfl_down(val, offset); val += __shfl_down(val, offset);
} }
return val; return val;
} }
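// Illustrative standalone sketch (not from this diff): a host-side simulation of what WarpReduceSum
// above computes. At each step every lane adds the value held by the lane `offset` positions above
// it (the effect of __shfl_down); after offsets 16, 8, 4, 2, 1 lane 0 holds the sum of all 32 lanes.
#include <cstdio>

int main() {
  constexpr int kWarpSize = 32;
  float lanes[kWarpSize];
  float expected = 0.0f;
  for (int i = 0; i < kWarpSize; ++i) {
    lanes[i] = static_cast<float>(i + 1);
    expected += lanes[i];
  }
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    // Ascending order mimics the simultaneous shuffle: each lane only reads from higher,
    // not-yet-updated lanes within this pass.
    for (int i = 0; i < kWarpSize; ++i) {
      const float other = (i + offset < kWarpSize) ? lanes[i + offset] : 0.0f;
      lanes[i] += other;
    }
  }
  std::printf("lane 0 = %.1f, expected = %.1f\n", lanes[0], expected);  // 528.0, 528.0
  return 0;
}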
template <typename InputType> template <typename InputType>
__launch_bounds__(1024) __global__ __launch_bounds__(1024) __global__
void bias_gradient_kernel_v2(float* dst, const InputType* src, int M, int N) { void bias_gradient_kernel_v2(float* dst, const InputType* src, int M, int N) {
__shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize]; __shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
const int j = blockIdx.x * blockDim.x + threadIdx.x; const int j = blockIdx.x * blockDim.x + threadIdx.x;
float grad_sum = 0.f; float grad_sum = 0.f;
if (j < N) { if (j < N) {
for (int i = threadIdx.y; i < M; i += blockDim.y) { for (int i = threadIdx.y; i < M; i += blockDim.y) {
grad_sum += static_cast<float>(src[i * N + j]); grad_sum += static_cast<float>(src[i * N + j]);
} }
} }
g_shared[threadIdx.y][threadIdx.x] = grad_sum; g_shared[threadIdx.y][threadIdx.x] = grad_sum;
__syncthreads(); __syncthreads();
float sum = g_shared[threadIdx.x][threadIdx.y]; float sum = g_shared[threadIdx.x][threadIdx.y];
sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2); sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
const int j = blockIdx.x * blockDim.x + threadIdx.y; const int j = blockIdx.x * blockDim.x + threadIdx.y;
if (j < N) { if (j < N) {
dst[j] = static_cast<float>(sum); dst[j] = static_cast<float>(sum);
} }
} }
} }
template <typename OutputType> template <typename OutputType>
__launch_bounds__(1024) __global__ __launch_bounds__(1024) __global__
void tensorwise_int8_bias_gradient_kernel(OutputType* dst, const int8_t* src, float* scale, int M, int N) { void tensorwise_int8_bias_gradient_kernel(OutputType* dst, const int8_t* src, float* scale, int M, int N) {
__shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize]; __shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
const int j = blockIdx.x * blockDim.x + threadIdx.x; const int j = blockIdx.x * blockDim.x + threadIdx.x;
float grad_sum = 0.f; float grad_sum = 0.f;
float tensorwise_scale = scale[0]; float tensorwise_scale = scale[0];
if (j < N) { if (j < N) {
for (int i = threadIdx.y; i < M; i += blockDim.y) { for (int i = threadIdx.y; i < M; i += blockDim.y) {
grad_sum += static_cast<float>(src[i * N + j]) * tensorwise_scale; grad_sum += static_cast<float>(src[i * N + j]) * tensorwise_scale;
} }
} }
g_shared[threadIdx.y][threadIdx.x] = grad_sum; g_shared[threadIdx.y][threadIdx.x] = grad_sum;
__syncthreads(); __syncthreads();
float sum = g_shared[threadIdx.x][threadIdx.y]; float sum = g_shared[threadIdx.x][threadIdx.y];
sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2); sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
const int j = blockIdx.x * blockDim.x + threadIdx.y; const int j = blockIdx.x * blockDim.x + threadIdx.y;
if (j < N) { if (j < N) {
dst[j] = static_cast<OutputType>(sum); dst[j] = static_cast<OutputType>(sum);
} }
} }
} }
template <typename Tin> template <typename Tin>
void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool stream_order_alloc, void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool stream_order_alloc,
hipStream_t stream) { hipStream_t stream) {
dim3 block, grid; dim3 block, grid;
constexpr int THREADS_PER_BLOCK = 1024; constexpr int THREADS_PER_BLOCK = 1024;
int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK); int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK);
block.x = THREADS_PER_BLOCK; block.x = THREADS_PER_BLOCK;
grid.x = BLOCKS_PER_COL * n; grid.x = BLOCKS_PER_COL * n;
if (!stream_order_alloc) { if (!stream_order_alloc) {
NVTE_CHECK_CUDA(hipMemset(out, 0, n * sizeof(float))); NVTE_CHECK_CUDA(hipMemset(out, 0, n * sizeof(float)));
} else { } else {
NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(float), stream)); NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(float), stream));
} }
// hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n); // hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
int B = (n - 1) / kColwiseReduceTileSize + 1; int B = (n - 1) / kColwiseReduceTileSize + 1;
bias_gradient_kernel_v2<Tin> bias_gradient_kernel_v2<Tin>
<<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, m, n); <<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, m, n);
} }
template <typename Tout> template <typename Tout>
void tensorwise_int8_bias_gradient_kernelLauncher(const int8_t* in, Tout* out, float* scale, int m, int n, hipStream_t stream) { void tensorwise_int8_bias_gradient_kernelLauncher(const int8_t* in, Tout* out, float* scale, int m, int n, hipStream_t stream) {
dim3 block, grid; dim3 block, grid;
constexpr int THREADS_PER_BLOCK = 1024; constexpr int THREADS_PER_BLOCK = 1024;
int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK); int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK);
block.x = THREADS_PER_BLOCK; block.x = THREADS_PER_BLOCK;
grid.x = BLOCKS_PER_COL * n; grid.x = BLOCKS_PER_COL * n;
NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(Tout), stream)); NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(Tout), stream));
int B = (n - 1) / kColwiseReduceTileSize + 1; int B = (n - 1) / kColwiseReduceTileSize + 1;
tensorwise_int8_bias_gradient_kernel<Tout> tensorwise_int8_bias_gradient_kernel<Tout>
<<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, scale, m, n); <<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, scale, m, n);
} }
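// Illustrative standalone sketch (not from this diff): the bias-gradient launchers above reduce an
// M x N row-major matrix column-wise; the tensorwise-int8 variant additionally applies a single
// dequantization scale. ref_int8_bias_grad is a hypothetical host reference, for illustration only.
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<float> ref_int8_bias_grad(const std::vector<int8_t>& src, float scale,
                                             int M, int N) {
  std::vector<float> dbias(N, 0.0f);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      dbias[j] += static_cast<float>(src[i * N + j]) * scale;  // dequantize, then sum per column
    }
  }
  return dbias;
}

int main() {
  const int M = 3, N = 2;
  const std::vector<int8_t> src = {1, 2, 3, 4, 5, 6};  // rows: [1,2], [3,4], [5,6]
  const auto dbias = ref_int8_bias_grad(src, /*scale=*/0.5f, M, N);
  std::printf("dbias = [%.1f, %.1f]\n", dbias[0], dbias[1]);  // [4.5, 6.0]
  return 0;
}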
} // namespace detail } // namespace detail
transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t) { transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t) {
using namespace transformer_engine; using namespace transformer_engine;
switch (t) { switch (t) {
case rocblas_datatype_f16_r: case rocblas_datatype_f16_r:
return DType::kFloat16; return DType::kFloat16;
case rocblas_datatype_f32_r: case rocblas_datatype_f32_r:
return DType::kFloat32; return DType::kFloat32;
case rocblas_datatype_bf16_r: case rocblas_datatype_bf16_r:
return DType::kBFloat16; return DType::kBFloat16;
case rocblas_datatype_f8_r: case rocblas_datatype_f8_r:
return DType::kFloat8E4M3; return DType::kFloat8E4M3;
case rocblas_datatype_bf8_r: case rocblas_datatype_bf8_r:
return DType::kFloat8E5M2; return DType::kFloat8E5M2;
default: default:
NVTE_ERROR("Invalid type"); NVTE_ERROR("Invalid type");
} }
} }
#endif //USE_ROCBLAS #endif //USE_ROCBLAS
#ifdef USE_HIPBLASLT #ifdef USE_HIPBLASLT
namespace { namespace {
static class HandlePool { class csv_helper {
public: public:
hipblasLtHandle_t get(int device_id) { struct start {};
std::lock_guard<std::mutex> lock(mt); struct end {};
if (pool.empty()) { csv_helper(std::ostream& os, char sep_val)
int device_count = 0; : m_os{os}, m_sep_val(sep_val), m_start(true), m_sep("") {}
NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
pool.resize(device_count); csv_helper& operator<<(const start&) {
return nullptr; m_start = true;
} return *this;
}
if (!pool[device_id].empty()) {
hipblasLtHandle_t h = pool[device_id].front(); csv_helper& operator<<(const end&) {
pool[device_id].pop_front(); m_sep = "";
return h; m_start = false;
} return *this;
}
return nullptr;
} template <typename T>
csv_helper& operator<<(const T& v) {
hipblasLtHandle_t obtain(int device_id) { m_os << m_sep << v;
hipblasLtHandle_t h = get(device_id); if (m_start) {
if (h == nullptr) { m_start = false;
NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&h)); m_sep = m_sep_val;
} }
return h; return *this;
} }
void store(const std::vector<hipblasLtHandle_t>& handles) { private:
std::lock_guard<std::mutex> lock(mt); std::ostream& m_os;
if (pool.empty()) { char m_sep_val;
std::cout << "[ERROR] Attempt to store handles to invalid pool" << std::endl; bool m_start;
} std::string m_sep;
for (unsigned int i = 0; i < pool.size(); i++) { };
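// Illustrative usage sketch (not from this diff) for the csv_helper above: values streamed into it
// are joined with the configured separator, and csv_helper::end() resets the state so the next row
// starts without a leading separator; this mirrors how header_() and save_() below emit CSV rows.
// The helper name csv_helper_demo_row is hypothetical.
static inline std::string csv_helper_demo_row() {
  std::ostringstream row;
  csv_helper csv(row, ',');
  csv << 90 << 4096 << 4096 << 1024 << "T" << "N" << csv_helper::end();
  return row.str();  // "90,4096,4096,1024,T,N"
}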
if (handles[i] != nullptr) {
pool[i].push_front(handles[i]); template <typename T>
} class NameMapper {
} public:
} NameMapper(const std::unordered_map<T, std::string_view>& name_map) : map(name_map) {}
const std::string_view& getName(const T& val) { return map.at(val); }
~HandlePool() { T getValue(const std::string& name, const char* label = "",
#if DESTROY_HIPBLASLT_HANDLES_POOL std::function<bool(const T&)> filter = nullptr) {
std::lock_guard<std::mutex> lock(mt); for (auto iter = map.begin(); iter != map.end(); ++iter) {
for (auto& hlist : pool) { if ((name == iter->second) && (!filter || filter(iter->first))) return iter->first;
for (auto& h : hlist) { }
hipblasLtDestroy(h); NVTE_ERROR("Invalid ", label, " name: ", name);
} }
}
pool.clear(); protected:
#endif const std::unordered_map<T, std::string_view>& map;
} };
inline size_t get_size() const { return pool.size(); } static std::unordered_map<hipDataType, std::string_view> type_name_map = {
{HIP_R_32F, "float32"},
private: {HIP_R_16F, "float16"},
std::mutex mt; {HIP_R_16BF, "bfloat16"},
using Pool = std::vector<std::forward_list<hipblasLtHandle_t>>; {HIP_R_8F_E4M3_FNUZ, "float8e4m3"},
// Order of destructors between thread_local and global is not actually guaranteed {HIP_R_8F_E5M2_FNUZ, "float8e5m2"},
// As a simple workaround, make the pool storage "leaky" #if HIP_VERSION >= 60300000
// Just do not destruct it and do not destroy hipblasLt handles {HIP_R_8F_E4M3, "float8e4m3"},
// Let OS deal with it on application exit {HIP_R_8F_E5M2, "float8e5m2"},
#if DESTROY_HIPBLASLT_HANDLES_POOL #endif
Pool pool; };
#else static NameMapper<hipDataType> typeNameMapper(type_name_map);
Pool& pool = *new Pool();
#endif static std::unordered_map<hipblasOperation_t, std::string_view> trans_name_map = {
} handle_pool; {HIPBLAS_OP_N, "N"}, {HIPBLAS_OP_T, "T"}};
static NameMapper<hipblasOperation_t> transposeNameMapper(trans_name_map);
thread_local static class HandleCache {
public: static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map = {
hipblasLtHandle_t get(int device_id) const { return d.empty() ? nullptr : d[device_id]; } {HIPBLASLT_EPILOGUE_DEFAULT, "-"}, {HIPBLASLT_EPILOGUE_BIAS, "bias"},
{HIPBLASLT_EPILOGUE_GELU_AUX, "geluaux"}, {HIPBLASLT_EPILOGUE_GELU_AUX_BIAS, "geluauxbias"},
hipblasLtHandle_t obtain(int device_id) { {HIPBLASLT_EPILOGUE_DGELU, "dgelu"}, {HIPBLASLT_EPILOGUE_DGELU_BGRAD, "dgelubgrad"},
hipblasLtHandle_t h = get(device_id); {HIPBLASLT_EPILOGUE_BGRADB, "bgradb"}};
if (h) { static NameMapper<hipblasLtEpilogue_t> epilogueNameMapper(epi_name_map);
return h;
} static std::unordered_map<hipblasComputeType_t, std::string_view> comp_name_map = {
h = handle_pool.obtain(device_id); {HIPBLAS_COMPUTE_32F, "f32"}};
set(device_id, h); static NameMapper<hipblasComputeType_t> computeNameMapper(comp_name_map);
return h;
} static class GemmAlgoCache {
public:
void set(int device_id, hipblasLtHandle_t h) { struct Key {
if (d.empty()) { int deviceCap;
d.resize(handle_pool.get_size()); hipDataType a_type, b_type, d_type, bias_type;
} int m, n, k;
d[device_id] = h; int lda, ldb, ldd;
} hipblasOperation_t transa, transb;
hipblasLtEpilogue_t epilogue;
~HandleCache() {
if (!d.empty()) { Key(int deviceCap_, hipDataType a_type_, hipDataType b_type_, hipDataType d_type_,
handle_pool.store(d); hipDataType bias_type_, int m_, int n_, int k_, int lda_, int ldb_, int ldd_,
} hipblasOperation_t transa_, hipblasOperation_t transb_, hipblasLtEpilogue_t epilogue_)
} : deviceCap(deviceCap_),
a_type(a_type_),
private: b_type(b_type_),
std::vector<hipblasLtHandle_t> d; d_type(d_type_),
} cached_handles; bias_type(bias_type_),
m(m_),
class csv_helper { n(n_),
public: k(k_),
struct start {}; lda(lda_),
struct end {}; ldb(ldb_),
ldd(ldd_),
csv_helper(std::ostream& os, char sep_val) transa(transa_),
: m_os{os}, m_sep_val(sep_val), m_start(true), m_sep("") {} transb(transb_),
epilogue(epilogue_) {}
csv_helper& operator<<(const start&) {
m_start = true; Key() {}
return *this;
} bool operator==(const Key& val) const {
return ((deviceCap == val.deviceCap) && (a_type == val.a_type) && (b_type == val.b_type) &&
csv_helper& operator<<(const end&) { (d_type == val.d_type) && (bias_type == val.bias_type) && (m == val.m) &&
m_sep = ""; (n == val.n) && (k == val.k) && (lda == val.lda) && (ldb == val.ldb) &&
m_start = false; (ldd == val.ldd) && (transa == val.transa) && (transb == val.transb) &&
return *this; (epilogue == val.epilogue));
} }
template <typename T> struct Comp {
csv_helper& operator<<(const T& v) { bool operator()(const Key& lhs, const Key& rhs) const {
m_os << m_sep << v; return ::std::string_view((const char*)&lhs, sizeof(lhs)) <
if (m_start) { ::std::string_view((const char*)&rhs, sizeof(rhs));
m_start = false; }
m_sep = m_sep_val; };
} };
return *this;
} void init() {
std::lock_guard<std::mutex> lock(mt);
private: int device_count = 0;
std::ostream& m_os; NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
char m_sep_val; dev_cap.resize(device_count);
bool m_start; for (int i = 0; i < device_count; i++) {
std::string m_sep; hipDeviceProp_t prop;
}; NVTE_CHECK_CUDA(hipGetDeviceProperties(&prop, i));
dev_cap[i] = prop.major * 100 + prop.minor;
template <typename T> }
class NameMapper { load_();
public: save_();
NameMapper(const std::unordered_map<T, std::string_view>& name_map) : map(name_map) {} }
const std::string_view& getName(const T& val) { return map.at(val); }
T getValue(const std::string& name, const char* label = "", inline int device_cap(int device_id) {
std::function<bool(const T&)> filter = nullptr) { if (dev_cap.empty()) init();
for (auto iter = map.begin(); iter != map.end(); ++iter) { return dev_cap[device_id];
if ((name == iter->second) && (!filter || filter(iter->first))) return iter->first; }
}
NVTE_ERROR("Invalid ", label, " name: ", name); struct Algo {
} std::optional<hipblasLtMatmulAlgo_t> algo;
int64_t algoId;
protected: int index;
const std::unordered_map<T, std::string_view>& map; size_t ws_size_min;
}; size_t ws_size_max;
Algo() : algo(), index(-1), algoId(), ws_size_min(0), ws_size_max(0) {}
static std::unordered_map<hipDataType, std::string_view> type_name_map = { Algo(int idx, int64_t id, size_t ws_min, size_t ws_max)
{HIP_R_32F, "float32"}, : algo(), index(idx), algoId(id), ws_size_min(ws_min), ws_size_max(ws_max) {}
{HIP_R_16F, "float16"}, inline bool hasId() { return index >= 0; }
{HIP_R_16BF, "bfloat16"}, const static inline int64_t getAlgoId(const hipblasLtMatmulAlgo_t& algo) {
{HIP_R_8F_E4M3_FNUZ, "float8e4m3"}, return *(const int64_t*)&algo;
{HIP_R_8F_E5M2_FNUZ, "float8e5m2"}, }
#if HIP_VERSION >= 60300000 };
{HIP_R_8F_E4M3, "float8e4m3"},
{HIP_R_8F_E5M2, "float8e5m2"}, bool find(const Key& cfg, size_t ws_size, Algo& algo) {
#endif std::lock_guard<std::mutex> lock(mt);
}; if (auto* pentry = find_(cfg, ws_size, ws_size); pentry != nullptr) {
static NameMapper<hipDataType> typeNameMapper(type_name_map); algo = *pentry;
return true;
static std::unordered_map<hipblasOperation_t, std::string_view> trans_name_map = { }
{HIPBLAS_OP_N, "N"}, {HIPBLAS_OP_T, "T"}}; return false;
static NameMapper<hipblasOperation_t> transposeNameMapper(trans_name_map); }
static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map = { void store(const Key& cfg, const Algo& algo) {
{HIPBLASLT_EPILOGUE_DEFAULT, "-"}, {HIPBLASLT_EPILOGUE_BIAS, "bias"}, size_t ws_size_min = algo.ws_size_min;
{HIPBLASLT_EPILOGUE_GELU_AUX, "geluaux"}, {HIPBLASLT_EPILOGUE_GELU_AUX_BIAS, "geluauxbias"}, size_t ws_size_max = algo.ws_size_max;
{HIPBLASLT_EPILOGUE_DGELU, "dgelu"}, {HIPBLASLT_EPILOGUE_DGELU_BGRAD, "dgelubgrad"}, NVTE_CHECK(ws_size_max >= ws_size_min, "Invalid WS size");
{HIPBLASLT_EPILOGUE_BGRADB, "bgradb"}}; std::lock_guard<std::mutex> lock(mt);
static NameMapper<hipblasLtEpilogue_t> epilogueNameMapper(epi_name_map);
//Remove overlapping with existing entries;
static std::unordered_map<hipblasComputeType_t, std::string_view> comp_name_map = { while (auto* pentry = find_(cfg, ws_size_min, ws_size_max)) {
{HIPBLAS_COMPUTE_32F, "f32"}}; if (pentry->ws_size_min <= ws_size_min && pentry->ws_size_max >= ws_size_max) {
static NameMapper<hipblasComputeType_t> computeNameMapper(comp_name_map); *pentry = algo;
save_();
static class GemmAlgoCache { return;
public: }
struct Key {
int deviceCap; if (ws_size_max > pentry->ws_size_max) {
hipDataType a_type, b_type, d_type, bias_type; ws_size_min = pentry->ws_size_max + 1;
int m, n, k; } else if (ws_size_min < pentry->ws_size_min) {
int lda, ldb, ldd; ws_size_max = pentry->ws_size_min - 1;
hipblasOperation_t transa, transb; } else {
hipblasLtEpilogue_t epilogue; //Should never be here
NVTE_ERROR("Cannot merge WS size range");
Key(int deviceCap_, hipDataType a_type_, hipDataType b_type_, hipDataType d_type_, }
hipDataType bias_type_, int m_, int n_, int k_, int lda_, int ldb_, int ldd_, }
hipblasOperation_t transa_, hipblasOperation_t transb_, hipblasLtEpilogue_t epilogue_)
: deviceCap(deviceCap_), //Merge to adjusted entry if possible
a_type(a_type_), auto* pentry = find_(cfg, ws_size_min - 1, ws_size_min);
b_type(b_type_), if (pentry && pentry->algoId == algo.algoId) {
d_type(d_type_), pentry->algo = algo.algo;
bias_type(bias_type_), pentry->ws_size_max = ws_size_max;
m(m_), save_();
n(n_), } else {
k(k_), auto it = d.emplace(cfg, algo);
lda(lda_), it->second.ws_size_min = ws_size_min;
ldb(ldb_), it->second.ws_size_max = ws_size_max;
ldd(ldd_), save_(it->first, it->second);
transa(transa_), }
transb(transb_), }
epilogue(epilogue_) {}
protected:
Key() {} Algo* find_(const Key& cfg, size_t ws_min, size_t ws_max) {
const auto key_range = d.equal_range(cfg);
bool operator==(const Key& val) const { for (auto i = key_range.first; i != key_range.second; i++) {
return ((deviceCap == val.deviceCap) && (a_type == val.a_type) && (b_type == val.b_type) && if (ws_min <= i->second.ws_size_max && ws_max >= i->second.ws_size_min) {
(d_type == val.d_type) && (bias_type == val.bias_type) && (m == val.m) && return &i->second;
(n == val.n) && (k == val.k) && (lda == val.lda) && (ldb == val.ldb) && }
(ldd == val.ldd) && (transa == val.transa) && (transb == val.transb) && }
(epilogue == val.epilogue)); return nullptr;
} }
struct Comp { void header_(std::ostream& ofs) {
bool operator()(const Key& lhs, const Key& rhs) const { csv_helper fs(ofs, csv_sep);
return ::std::string_view((const char*)&lhs, sizeof(lhs)) < fs << "dev_cap" << "m" << "n" << "k" << "trans_a" << "trans_b"
::std::string_view((const char*)&rhs, sizeof(rhs)); << "type_a" << "type_b" << "type_d" << "bias_type"
} << "lda" << "ldb" << "ldd" << "epi" << "comp" << "scale"
}; << "ws_min" << "ws_max" << "algo_id" << "aidx";
}; }
void init() { void load_() {
std::lock_guard<std::mutex> lock(mt); const char* env = std::getenv("TE_HIPBLASLT_ALGO_LOAD");
int device_count = 0; if (env == nullptr || env[0] == '\0') {
NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count)); return;
dev_cap.resize(device_count); }
for (int i = 0; i < device_count; i++) { std::ifstream ifs{env};
hipDeviceProp_t prop; if (!ifs.is_open()) {
NVTE_CHECK_CUDA(hipGetDeviceProperties(&prop, i)); std::cerr << "Could not load autotune results storage " << env << "\n";
dev_cap[i] = prop.major * 100 + prop.minor; return;
} }
load_(); std::cout << "Loading autotune results from " << env << "\n";
save_();
} Key cfg;
std::string line;
inline int device_cap(int device_id) { std::getline(ifs, line); // the first line with legend
if (dev_cap.empty()) init(); {
return dev_cap[device_id]; std::ostringstream hline;
} header_(hline);
if (hline.str() != line) {
struct Algo { std::cerr << "Incorrect algo storage legend. Expected " << hline.str() << "\n";
std::optional<hipblasLtMatmulAlgo_t> algo; return;
int64_t algoId; }
int index; }
size_t ws_size_min;
size_t ws_size_max; while (std::getline(ifs, line)) {
Algo() : algo(), index(-1), algoId(), ws_size_min(0), ws_size_max(0) {} line.erase(0, line.find_first_not_of(" \t\n\r\f\v"));
Algo(int idx, int64_t id, size_t ws_min, size_t ws_max) if (auto pos = line.find_last_not_of(" \t\n\r\f\v"); pos != std::string::npos) {
: algo(), index(idx), algoId(id), ws_size_min(ws_min), ws_size_max(ws_max) {} line.resize(pos + 1);
inline bool hasId() { return index >= 0; } }
const static inline int64_t getAlgoId(const hipblasLtMatmulAlgo_t& algo) { if (line.empty() || line[0] == '#') continue;
return *(const int64_t*)&algo; std::istringstream is(line);
} char c;
}; std::string type_a, type_b, type_d, bias_type, trans_a, trans_b, epi, comp, scale;
int64_t algo_id;
bool find(const Key& cfg, size_t ws_size, Algo& algo) { int algo_idx;
std::lock_guard<std::mutex> lock(mt); size_t ws_min, ws_max;
if (auto* pentry = find_(cfg, ws_size, ws_size); pentry != nullptr) {
algo = *pentry; is >> std::skipws;
return true; is >> cfg.deviceCap >> c >> cfg.m >> c >> cfg.n >> c >> cfg.k >> c;
}
return false; //Filter out entries for devices not present on the current system
} bool b_found = false;
for (int i = 0; i < dev_cap.size(); i++) {
void store(const Key& cfg, const Algo& algo) { if (dev_cap[i] == cfg.deviceCap) {
size_t ws_size_min = algo.ws_size_min; b_found = true;
size_t ws_size_max = algo.ws_size_max; break;
NVTE_CHECK(ws_size_max >= ws_size_min, "Invalid WS size"); }
std::lock_guard<std::mutex> lock(mt); }
if (!b_found) continue;
//Remove overlapping with existing entries;
while (auto* pentry = find_(cfg, ws_size_min, ws_size_max)) { std::getline(is, trans_a, csv_sep);
if (pentry->ws_size_min <= ws_size_min && pentry->ws_size_max >= ws_size_max) { std::getline(is, trans_b, csv_sep);
*pentry = algo; std::getline(is, type_a, csv_sep);
save_(); std::getline(is, type_b, csv_sep);
return; std::getline(is, type_d, csv_sep);
} std::getline(is, bias_type, csv_sep);
is >> cfg.lda >> c >> cfg.ldb >> c >> cfg.ldd >> c;
if (ws_size_max > pentry->ws_size_max) { std::getline(is, epi, csv_sep);
ws_size_min = pentry->ws_size_max + 1; std::getline(is, comp, csv_sep);
} else if (ws_size_min < pentry->ws_size_min) { std::getline(is, scale, csv_sep);
ws_size_max = pentry->ws_size_min - 1; is >> ws_min >> c >> ws_max >> c >> algo_id >> c >> algo_idx;
} else {
//Should never be here if (is.bad()) {
NVTE_ERROR("Cannot merge WS size range"); std::cerr << "Parsing CSV line failed: " << line << "\n";
} return;
} }
//Merge to adjusted entry if possible if (ws_min > ws_max) {
auto* pentry = find_(cfg, ws_size_min - 1, ws_size_min); std::cout << "[WARNING] Invalid WS size at " << line << "\n";
if (pentry && pentry->algoId == algo.algoId) { continue;
pentry->algo = algo.algo; }
pentry->ws_size_max = ws_size_max;
save_(); #if HIP_VERSION >= 60300000
} else { auto fp8_filter = [](const hipDataType& val) {
auto it = d.emplace(cfg, algo); return (val != HIP_R_8F_E4M3_FNUZ && val != HIP_R_8F_E5M2_FNUZ);
it->second.ws_size_min = ws_size_min; };
it->second.ws_size_max = ws_size_max; #else
save_(it->first, it->second); auto fp8_filter = nullptr;
} #endif
}
cfg.a_type = typeNameMapper.getValue(type_a, "type_a", fp8_filter);
protected: cfg.b_type = typeNameMapper.getValue(type_b, "type_b", fp8_filter);
Algo* find_(const Key& cfg, size_t ws_min, size_t ws_max) { cfg.d_type = typeNameMapper.getValue(type_d, "type_d", fp8_filter);
const auto key_range = d.equal_range(cfg); cfg.bias_type = (bias_type == "-")
for (auto i = key_range.first; i != key_range.second; i++) { ? (hipDataType)-1
if (ws_min <= i->second.ws_size_max && ws_max >= i->second.ws_size_min) { : typeNameMapper.getValue(bias_type, "bias_type", fp8_filter);
return &i->second;
} cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a");
} cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b");
return nullptr;
} cfg.epilogue = epilogueNameMapper.getValue(epi, "epi");
//Check and filter out compute and scale types
void header_(std::ostream& ofs) { if (computeNameMapper.getValue(comp, "comp") != HIPBLAS_COMPUTE_32F ||
csv_helper fs(ofs, csv_sep); typeNameMapper.getValue(scale, "scale") != HIP_R_32F) {
fs << "dev_cap" << "m" << "n" << "k" << "trans_a" << "trans_b" continue;
<< "type_a" << "type_b" << "type_d" << "bias_type" }
<< "lda" << "ldb" << "ldd" << "epi" << "comp" << "scale"
<< "ws_min" << "ws_max" << "algo_id" << "aidx"; if (find_(cfg, ws_min, ws_max)) {
} std::cout << "[WARNING] Duplicated/overlapped entry in algo cache\n";
continue;
void load_() { }
const char* env = std::getenv("TE_HIPBLASLT_ALGO_LOAD");
if (env == nullptr || env[0] == '\0') { d.emplace(cfg, Algo(algo_idx, algo_id, ws_min, ws_max));
return; }
} }
std::ifstream ifs{env};
if (!ifs.is_open()) { bool can_save_(bool reopen = false) {
std::cerr << "Could not load autotune results storage " << env << "\n"; if (!save_fs) {
return; const char* temp = std::getenv("TE_HIPBLASLT_ALGO_SAVE");
} if (temp == nullptr || temp[0] == '\0') {
std::cout << "Loading autotune results from " << env << "\n"; return false;
}
Key cfg;
std::string line; save_fs_name = temp;
std::getline(ifs, line); // the first line with legend
{ pid_t pid = getpid();
std::ostringstream hline;
header_(hline); size_t pos = 0;
if (hline.str() != line) { while ((pos = save_fs_name.find("%i", pos)) != std::string::npos) {
std::cerr << "Incorrect algo storage legend. Expected " << hline.str() << "\n"; save_fs_name.replace(pos, 2, std::to_string(pid));
return; }
}
} save_fs = std::make_unique<std::ofstream>();
std::cout << "Saving autotune results to " << save_fs_name << "\n";
while (std::getline(ifs, line)) { }
line.erase(0, line.find_first_not_of(" \t\n\r\f\v"));
if (auto pos = line.find_last_not_of(" \t\n\r\f\v"); pos != std::string::npos) { if (reopen) {
line.resize(pos + 1); if (save_fs->is_open()) {
} save_fs->close();
if (line.empty() || line[0] == '#') continue; }
std::istringstream is(line); save_fs->open(save_fs_name, std::ios_base::trunc);
char c; }
std::string type_a, type_b, type_d, bias_type, trans_a, trans_b, epi, comp, scale;
int64_t algo_id; if (save_fs->is_open() && !save_fs->bad()) {
int algo_idx; return true;
size_t ws_min, ws_max; } else {
if (reopen) std::cerr << "Could not open autotune results storage " << save_fs_name << "\n";
is >> std::skipws; return false;
is >> cfg.deviceCap >> c >> cfg.m >> c >> cfg.n >> c >> cfg.k >> c; }
}
//Filter out entries for devices not present on the current system
bool b_found = false; void save_() {
for (int i = 0; i < dev_cap.size(); i++) { if (!can_save_(true)) {
if (dev_cap[i] == cfg.deviceCap) { return;
b_found = true; }
break; header_(*save_fs);
} *save_fs << "\n";
}
if (!b_found) continue; for (const auto& elem : d) {
save_(elem.first, elem.second);
std::getline(is, trans_a, csv_sep); }
std::getline(is, trans_b, csv_sep); }
std::getline(is, type_a, csv_sep);
std::getline(is, type_b, csv_sep); void save_(const Key& cfg, const Algo& algo) {
std::getline(is, type_d, csv_sep); if (!can_save_()) {
std::getline(is, bias_type, csv_sep); return;
is >> cfg.lda >> c >> cfg.ldb >> c >> cfg.ldd >> c; }
std::getline(is, epi, csv_sep); csv_helper csv(*save_fs, csv_sep);
std::getline(is, comp, csv_sep); csv << cfg.deviceCap << cfg.m << cfg.n << cfg.k << transposeNameMapper.getName(cfg.transa)
std::getline(is, scale, csv_sep); << transposeNameMapper.getName(cfg.transb) << typeNameMapper.getName(cfg.a_type)
is >> ws_min >> c >> ws_max >> c >> algo_id >> c >> algo_idx; << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
<< ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
if (is.bad()) { << cfg.lda << cfg.ldb << cfg.ldd << epilogueNameMapper.getName(cfg.epilogue)
std::cerr << "Parsing CSV line failed: " << line << "\n"; << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F)
return; << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end()
} << "\n";
}
if (ws_min > ws_max) {
std::cout << "[WARNING] Invalid WS size at " << line << "\n"; private:
continue; std::vector<int> dev_cap;
} constexpr static char csv_sep = ',';
std::unique_ptr<std::ofstream> save_fs;
#if HIP_VERSION >= 60300000 std::string save_fs_name;
auto fp8_filter = [](const hipDataType& val) { std::mutex mt;
return (val != HIP_R_8F_E4M3_FNUZ && val != HIP_R_8F_E5M2_FNUZ); /* Map of problem config to tuple of ws_size and Algo
}; * When searching, elements matching Key are filtered
#else * for requested WS size be between Algo.ws_size and pair.first
auto fp8_filter = nullptr; */
#endif std::multimap<Key, Algo, Key::Comp> d;
} algoCache;
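// Illustrative standalone sketch (not from this diff): how a save-file name containing the "%i"
// placeholder is expanded, mirroring the substitution done in can_save_() above. A typical flow
// (assumption, based on the load_/save_ code) is to run once with TE_HIPBLASLT_ALGO_SAVE pointing
// at a CSV (optionally containing "%i" for the PID), then rerun with TE_HIPBLASLT_ALGO_LOAD set to it.
#include <string>
#include <unistd.h>
#include <cstdio>

int main() {
  std::string name = "/tmp/hipblaslt_algos_%i.csv";  // example TE_HIPBLASLT_ALGO_SAVE value
  const std::string pid = std::to_string(getpid());
  for (size_t pos = 0; (pos = name.find("%i", pos)) != std::string::npos; pos += pid.size()) {
    name.replace(pos, 2, pid);  // same "%i" -> PID replacement as can_save_()
  }
  std::printf("%s\n", name.c_str());  // e.g. /tmp/hipblaslt_algos_12345.csv
  return 0;
}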
cfg.a_type = typeNameMapper.getValue(type_a, "type_a", fp8_filter);
cfg.b_type = typeNameMapper.getValue(type_b, "type_b", fp8_filter); static inline int getIntEnv(const char* name, int defval, int minval) {
cfg.d_type = typeNameMapper.getValue(type_d, "type_d", fp8_filter); int val = defval;
cfg.bias_type = (bias_type == "-") const char* env = std::getenv(name);
? (hipDataType)-1 if (env != nullptr && env[0] != '\0') {
: typeNameMapper.getValue(bias_type, "bias_type", fp8_filter); val = atoi(env);
if (val < minval) {
cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a"); val = minval;
cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b"); }
}
cfg.epilogue = epilogueNameMapper.getValue(epi, "epi"); return val;
//Check and filter out compute and scale types }
if (computeNameMapper.getValue(comp, "comp") != HIPBLAS_COMPUTE_32F ||
typeNameMapper.getValue(scale, "scale") != HIP_R_32F) { } //namespace
continue;
} static inline void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
NVTE_CHECK_HIPBLASLT(hipblasLtCreate(handle));
if (find_(cfg, ws_min, ws_max)) { }
std::cout << "[WARNING] Duplicated/overlapped entry in algo cache\n";
continue; using hipBlasLtHandleManager = detail::HandleManager<hipblasLtHandle_t, CreateHipBlasLtHandle>;
}
transformer_engine::DType get_transformer_engine_dtype_from_hipblaslt_dtype(const hipDataType t) {
d.emplace(cfg, Algo(algo_idx, algo_id, ws_min, ws_max)); using namespace transformer_engine;
} switch (t) {
} case HIP_R_16F:
return DType::kFloat16;
bool can_save_(bool reopen = false) { case HIP_R_32F:
if (!save_fs) { return DType::kFloat32;
const char* temp = std::getenv("TE_HIPBLASLT_ALGO_SAVE"); case HIP_R_16BF:
if (temp == nullptr || temp[0] == '\0') { return DType::kBFloat16;
return false; default:
} NVTE_ERROR("Invalid type");
}
save_fs_name = temp; }
pid_t pid = getpid(); void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
size_t pos = 0; int ldb, int ldd, hipblasOperation_t transa, hipblasOperation_t transb,
while ((pos = save_fs_name.find("%i", pos)) != std::string::npos) { bool grad, void* workspace, size_t workspaceSize, bool accumulate,
save_fs_name.replace(pos, 2, std::to_string(pid)); bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
} bool gemm_producer, const Tensor* inputCounter, hipStream_t stream) {
void* A = inputA->data.dptr;
save_fs = std::make_unique<std::ofstream>(); void* A_scale_inverse = inputA->scale_inv.dptr;
std::cout << "Saving autotune results to " << save_fs_name << "\n"; float* A_scale_inverse_float = (float*)(inputA->scale_inv.dptr);
} void* B = inputB->data.dptr;
void* B_scale_inverse = inputB->scale_inv.dptr;
if (reopen) { float* B_scale_inverse_float = (float*)(inputB->scale_inv.dptr);
if (save_fs->is_open()) { void* D = outputD->data.dptr;
save_fs->close(); void* bias_ptr = inputBias->data.dptr;
} const bool bias = bias_ptr != nullptr;
save_fs->open(save_fs_name, std::ios_base::trunc); void* pre_gelu_out = outputPreGelu->data.dptr;
} const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
if (save_fs->is_open() && !save_fs->bad()) { const bool use_int8 = is_int8_dtype(inputA->data.dtype) || is_int8_dtype(inputB->data.dtype);
return true; const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
} else { const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
if (reopen) std::cerr << "Could not open autotune results storage " << save_fs_name << "\n"; const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
return false; const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
}
} NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
"FP8 input to GEMM requires inverse of scale!");
void save_() { NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
if (!can_save_(true)) { "FP8 input to GEMM requires inverse of scale!");
return; NVTE_CHECK(!is_int8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
} "INT8 input to GEMM requires inverse of scale!");
header_(*save_fs); NVTE_CHECK(!is_int8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
*save_fs << "\n"; "INT8 input to GEMM requires inverse of scale!");
for (const auto& elem : d) { bool tensorwise_int8 = 0;;
save_(elem.first, elem.second); const char* NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
} if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' && use_int8) tensorwise_int8 = 1;
}
// check consistency of arguments:
void save_(const Key& cfg, const Algo& algo) { // if fp8 is desired, context cannot be null
if (!can_save_()) { // fp8 + gelu fusion + fp8 aux is unavailable right now.
return; if (use_fp8 || use_int8) {
} NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
csv_helper csv(*save_fs, csv_sep); }
csv << cfg.deviceCap << cfg.m << cfg.n << cfg.k << transposeNameMapper.getName(cfg.transa) float one = 1.0;
<< transposeNameMapper.getName(cfg.transb) << typeNameMapper.getName(cfg.a_type) float zero = 0.0;
<< typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type) float beta = (accumulate) ? one : zero;
<< ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
<< cfg.lda << cfg.ldb << cfg.ldd << epilogueNameMapper.getName(cfg.epilogue) int device_id;
<< computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F) NVTE_CHECK_CUDA(hipGetDevice(&device_id));
<< algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end()
<< "\n"; hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
}
hipblasLtMatmulDesc_t operationDesc = nullptr;
private: hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
std::vector<int> dev_cap; hipblasLtMatmulPreference_t preference = nullptr;
constexpr static char csv_sep = ','; hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
std::unique_ptr<std::ofstream> save_fs;
std::string save_fs_name; int64_t ld_gelumat = (int64_t)ldd;
std::mutex mt;
/* Map of problem config to tuple of ws_size and Algo // default to tf32 except for e5m2 inputs where the config is not supported
* When searching, elements matching Key are filtered hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;
* for requested WS size be between Algo.ws_size and pair.first
*/ // Create matrix descriptors. Not setting any extra attributes.
std::multimap<Key, Algo, Key::Comp> d; NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type, transa == HIPBLAS_OP_N ? m : k,
} algoCache; transa == HIPBLAS_OP_N ? k : m, lda));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Bdesc, B_type, transb == HIPBLAS_OP_N ? k : n,
static inline int getIntEnv(const char* name, int defval, int minval) { transb == HIPBLAS_OP_N ? n : k, ldb));
int val = defval; NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
const char* env = std::getenv(name);
if (env != nullptr && env[0] != '\0') { NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
val = atoi(env); NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
if (val < minval) { &transa, sizeof(transa)));
val = minval; NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
} &transb, sizeof(transb)));
}
return val; // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
} // Note: gelu fusion isn't available right now, and we don't need
// amax(D) either (next op is high precision).
} //namespace if (use_fp8) {
// Split accumulator.
/* Warning: only call once per device! const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
* When calling nvte_multi_stream_cublas_gemm with hipblaslt backend /*
* need to create multiple handles corresponding to compute_streams NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
* to avoid a handle being used by multiple streams concurrently. HIPBLASLT_MATMUL_DESC_FAST_ACCUM, //TODO: We don't have fast accum mode yet
*/ &fastAccuMode,
static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) { sizeof(fastAccuMode)));
NVTE_CHECK(hipblaslt_handles != nullptr); */
for (int i = 0; i < compute_num_streams; i++) { NVTE_CHECK_HIPBLASLT(
NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&hipblaslt_handles[i])); hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
} &A_scale_inverse, sizeof(A_scale_inverse)));
} NVTE_CHECK_HIPBLASLT(
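// Minimal sketch (assumption-laden, not this file's API) of creating the
// per-stream hipBLASLt handles exactly once, as the init helper above is meant
// to be used. kNumComputeStreams is a hypothetical stand-in for this file's
// compute_num_streams; header path assumes a standard ROCm install.
#include <mutex>
#include <hipblaslt/hipblaslt.h>

namespace example {
constexpr int kNumComputeStreams = 4;  // stand-in value

inline hipblasLtHandle_t example_handle_for_stream(int stream_idx) {
  static hipblasLtHandle_t handles[kNumComputeStreams];
  static std::once_flag once;
  // Create all handles exactly once; each compute stream then reuses its own
  // handle, so a single handle is never shared by two streams concurrently.
  std::call_once(once, []() {
    for (int i = 0; i < kNumComputeStreams; ++i) {
      (void)hipblasLtCreate(&handles[i]);  // error handling omitted in this sketch
    }
  });
  return handles[stream_idx];
}
}  // namespace example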
hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
transformer_engine::DType get_transformer_engine_dtype_from_hipblaslt_dtype(const hipDataType t) { &B_scale_inverse, sizeof(B_scale_inverse)));
using namespace transformer_engine; if (bias) {
switch (t) { NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
case HIP_R_16F: operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
return DType::kFloat16; }
case HIP_R_32F: }
return DType::kFloat32; if (tensorwise_int8) {
case HIP_R_16BF: NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
return DType::kBFloat16; HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
default: (void*)&A_scale_inverse_float,
NVTE_ERROR("Invalid type"); sizeof(void*)));
} NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
} HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
(void*)&B_scale_inverse_float,
void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD, sizeof(void*)));
const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda, }
int ldb, int ldd, hipblasOperation_t transa, hipblasOperation_t transb,
bool grad, void* workspace, size_t workspaceSize, bool accumulate, if (bias && gelu) {
bool use_split_accumulator, int math_sm_count, int m_split, int n_split, if (grad) {
bool gemm_producer, const Tensor* inputCounter, hipStream_t stream, epilogue = HIPBLASLT_EPILOGUE_DGELU_BGRAD;
hipblasLtHandle_t handle) { } else {
void* A = inputA->data.dptr; epilogue = HIPBLASLT_EPILOGUE_GELU_AUX_BIAS;
void* A_scale_inverse = inputA->scale_inv.dptr; }
float* A_scale_inverse_float = (float*)(inputA->scale_inv.dptr); NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
void* B = inputB->data.dptr; operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
void* B_scale_inverse = inputB->scale_inv.dptr; NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
float* B_scale_inverse_float = (float*)(inputB->scale_inv.dptr); HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
void* D = outputD->data.dptr; &pre_gelu_out, sizeof(pre_gelu_out)));
void* bias_ptr = inputBias->data.dptr; NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
const bool bias = bias_ptr != nullptr; operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
void* pre_gelu_out = outputPreGelu->data.dptr; } else if (bias) {
const bool gelu = pre_gelu_out != nullptr; if (tensorwise_int8) {
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype); if (grad) {
const bool use_int8 = is_int8_dtype(inputA->data.dtype) || is_int8_dtype(inputB->data.dtype); int batch_size = k;
const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype); int output_dim = n;
const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype); DType te_bias_dtype = get_transformer_engine_dtype_from_hipblaslt_dtype(bias_type);
const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype); te_bias_dtype, BType,
detail::tensorwise_int8_bias_gradient_kernelLauncher<BType>(
NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr, reinterpret_cast<const int8_t*>(B), reinterpret_cast<BType*>(bias_ptr), B_scale_inverse_float, batch_size,
"FP8 input to GEMM requires inverse of scale!"); output_dim, stream););
NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr, } else {
"FP8 input to GEMM requires inverse of scale!"); NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
NVTE_CHECK(!is_int8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr, operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
"INT8 input to GEMM requires inverse of scale!"); epilogue = HIPBLASLT_EPILOGUE_BIAS;
NVTE_CHECK(!is_int8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr, NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
"INT8 input to GEMM requires inverse of scale!"); operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
}
bool tensorwise_int8 = false; } else {
const char* NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE"); if (grad) {
if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' && use_int8) tensorwise_int8 = 1; // grad output is always input B
epilogue = HIPBLASLT_EPILOGUE_BGRADB;
// check consistency of arguments: } else {
// if fp8 is desired, context cannot be null epilogue = HIPBLASLT_EPILOGUE_BIAS;
// fp8 + gelu fusion + fp8 aux is unavailable right now. }
if (use_fp8 || use_int8) { NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!"); operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
} }
float one = 1.0;
float zero = 0.0; } else if (gelu) {
float beta = (accumulate) ? one : zero; if (grad) {
epilogue = HIPBLASLT_EPILOGUE_DGELU;
int device_id; } else {
NVTE_CHECK_CUDA(hipGetDevice(&device_id)); epilogue = HIPBLASLT_EPILOGUE_GELU_AUX;
}
if (handle == nullptr) { NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
handle = cached_handles.get(device_id); HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
if (handle == nullptr) { &pre_gelu_out, sizeof(pre_gelu_out)));
handle = cached_handles.obtain(device_id); NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
} operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
} }
hipblasLtMatmulDesc_t operationDesc = nullptr; NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr; operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
hipblasLtMatmulPreference_t preference = nullptr;
hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT; GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
use_fp8 ? bias_type : (hipDataType)-1, m, n, k, lda, ldb, ldd, transa,
int64_t ld_gelumat = (int64_t)ldd; transb, epilogue);
GemmAlgoCache::Algo cached_algo;
// default to tf32 except for e5m2 inputs where the config is not supported if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value()) {
hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F; int firstAlgo = getIntEnv("TE_HIPBLASLT_ALGO_SELECTION", 0, 0);
int tuneLoopCount = getIntEnv("TE_HIPBLASLT_TUNING_RUN_COUNT", 0, 0);
// Create matrix descriptors. Not setting any extra attributes. int algoTuneCount = 1;
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type, transa == HIPBLAS_OP_N ? m : k, std::vector<hipblasLtMatmulHeuristicResult_t> algoArr;
transa == HIPBLAS_OP_N ? k : m, lda)); bool logTuning = getIntEnv("TE_HIPBLASLT_LOG_TUNING", 0, 0) != 0;
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Bdesc, B_type, transb == HIPBLAS_OP_N ? k : n,
transb == HIPBLAS_OP_N ? n : k, ldb)); if (tuneLoopCount) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); /* HIPBLASLT may return hundreds of algos for some configs
* Limit amount by default. User may override with env
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F)); */
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA, static const int defaultAlgoCount = 16;
&transa, sizeof(transa))); algoTuneCount = getIntEnv("TE_HIPBLASLT_TUNING_ALGO_COUNT", defaultAlgoCount, 1);
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB, }
&transb, sizeof(transb))); algoTuneCount += firstAlgo;
int algoTotalCount =
// set fp8 attributes -- input and output types should already be set to fp8 as appropriate cached_algo.hasId() ? std::max(algoTuneCount, (cached_algo.index + 1)) : algoTuneCount;
// Note: gelu fusion isn't available right now, and we don't need algoArr.resize(algoTotalCount);
// amax(D) either (next op is high precision).
if (use_fp8) { NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceCreate(&preference));
// Split accumulator. NVTE_CHECK_HIPBLASLT(
const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1; hipblasLtMatmulPreferenceSetAttribute(preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
/* &workspaceSize, sizeof(workspaceSize)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_FAST_ACCUM, //TODO: We don't have fast accum mode yet NVTE_CHECK_HIPBLASLT(hipblasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Ddesc,
&fastAccuMode, Ddesc, preference, algoTotalCount,
sizeof(fastAccuMode))); algoArr.data(), &algoTotalCount));
*/ algoArr.resize(algoTotalCount);
NVTE_CHECK_HIPBLASLT(
hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceDestroy(preference));
&A_scale_inverse, sizeof(A_scale_inverse)));
NVTE_CHECK_HIPBLASLT( //If cached algo exists in persistent storage we just need to find matching hipblasLtMatmulAlgo_t
hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, if (cached_algo.hasId()) {
&B_scale_inverse, sizeof(B_scale_inverse))); int idx = (cached_algo.index < algoTotalCount) ? cached_algo.index : 0;
if (bias) { for (int i = 0; i < algoTotalCount; i++) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute( const auto& algo = algoArr[idx];
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type))); if (algo.state == HIPBLAS_STATUS_SUCCESS) {
} if (cached_algo.algoId == cached_algo.getAlgoId(algo.algo)) {
} cached_algo.algo = algo.algo;
if (tensorwise_int8) { if (algo.workspaceSize != cached_algo.ws_size_min || idx != cached_algo.index) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, cached_algo.ws_size_min = algo.workspaceSize;
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, cached_algo.index = idx;
(void*)&A_scale_inverse_float, algoCache.store(gemm_cfg, cached_algo);
sizeof(void*))); }
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, break;
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, }
(void*)&B_scale_inverse_float, }
sizeof(void*))); idx = (idx + 1) % algoTotalCount;
} }
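// CPU reference (illustrative sketch, not this file's kernel) for the
// tensorwise INT8 path gated above: each tensor carries a single float
// scale_inv, the products are accumulated in a wide integer type in this
// sketch, and the result is rescaled by a_scale_inv * b_scale_inv.
#include <cstdint>
#include <vector>

inline std::vector<float> example_int8_gemm_tensorwise(
    const std::vector<int8_t>& a,  // m x k, row major
    const std::vector<int8_t>& b,  // k x n, row major
    int m, int n, int k, float a_scale_inv, float b_scale_inv) {
  std::vector<float> d(static_cast<size_t>(m) * n, 0.f);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      int32_t acc = 0;  // integer accumulation for this reference
      for (int p = 0; p < k; ++p) {
        acc += static_cast<int32_t>(a[i * k + p]) * static_cast<int32_t>(b[p * n + j]);
      }
      // The per-tensor scale inverses recover the real-valued product.
      d[i * static_cast<size_t>(n) + j] = static_cast<float>(acc) * a_scale_inv * b_scale_inv;
    }
  }
  return d;
}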
if (logTuning && !cached_algo.algo.has_value()) {
if (bias && gelu) { std::cout << "[WARNING] Cannot find cached algoId " << cached_algo.algoId
if (grad) { << " in hipBLASLt results" << std::endl;
epilogue = HIPBLASLT_EPILOGUE_DGELU_BGRAD; }
} else { }
epilogue = HIPBLASLT_EPILOGUE_GELU_AUX_BIAS;
} //No suitable entry in autotune cache or could not find matched algo in hipBLASLt results
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute( if (!cached_algo.algo.has_value()) {
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr))); int bestAlgo = -1;
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, algoTuneCount = std::min(algoTuneCount, algoTotalCount);
HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, if (tuneLoopCount > 0) {
&pre_gelu_out, sizeof(pre_gelu_out))); if (logTuning)
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute( std::cout << "[INFO] Perform hipBLASLt algo selection on GPU " << device_id
operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat))); << " in range [" << firstAlgo << "-" << (algoTuneCount - 1) << "] with "
} else if (bias) { << tuneLoopCount << " loops " << std::endl;
if (tensorwise_int8) {
if (grad) { NVTE_CHECK_CUDA(hipStreamSynchronize(stream));
int batch_size = k; hipStream_t profilingStream;
int output_dim = n; NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
DType te_bias_dtype = get_transformer_engine_dtype_from_hipblaslt_dtype(bias_type); using tuning_clock = std::chrono::steady_clock;
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tuning_clock::now(); // the first call takes a little longer, so do it outside the loop
te_bias_dtype, BType, tuning_clock::duration bestTime = tuning_clock::duration::max();
detail::tensorwise_int8_bias_gradient_kernelLauncher<BType>(
reinterpret_cast<const int8_t*>(B), reinterpret_cast<BType*>(bias_ptr), B_scale_inverse_float, batch_size, for (int algo = firstAlgo; algo < algoTuneCount; algo++) {
output_dim, stream);); if (algoArr[algo].state != HIPBLAS_STATUS_SUCCESS) {
} else { continue;
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute( }
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type))); // Warm-up call
epilogue = HIPBLASLT_EPILOGUE_BIAS; NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute( static_cast<const void*>(&one), /* alpha */
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr))); A, /* A */
} Adesc, B, /* B */
} else { Bdesc, static_cast<const void*>(&beta), /* beta */
if (grad) { D, /* C */
// grad output is always input B Ddesc, D, /* D */
epilogue = HIPBLASLT_EPILOGUE_BGRADB; Ddesc, &algoArr[algo].algo, /* algo */
} else { workspace, /* workspace */
epilogue = HIPBLASLT_EPILOGUE_BIAS; workspaceSize, profilingStream)); /* stream */
} NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr))); //Profiling loop
} tuning_clock::time_point startTime = tuning_clock::now();
for (int loop = 0; loop < tuneLoopCount; loop++) {
} else if (gelu) { NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
if (grad) { static_cast<const void*>(&one), /* alpha */
epilogue = HIPBLASLT_EPILOGUE_DGELU; A, /* A */
} else { Adesc, B, /* B */
epilogue = HIPBLASLT_EPILOGUE_GELU_AUX; Bdesc, static_cast<const void*>(&beta), /* beta */
} D, /* C */
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, Ddesc, D, /* D */
HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, Ddesc, &algoArr[algo].algo, /* algo */
&pre_gelu_out, sizeof(pre_gelu_out))); workspace, /* workspace */
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute( workspaceSize, profilingStream)); /* stream */
operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat))); }
} NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
tuning_clock::duration algoTime = tuning_clock::now() - startTime;
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute( if (algoTime < bestTime) {
operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); bestAlgo = algo;
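// Compact restatement (illustrative sketch) of the epilogue choice made by the
// if/else chain above, ignoring the tensorwise INT8 special case; the
// surrounding code additionally wires up the bias and aux pointers, which a
// real caller still has to do. Header path assumes a standard ROCm install.
#include <hipblaslt/hipblaslt.h>

inline hipblasLtEpilogue_t example_select_epilogue(bool with_bias, bool with_gelu, bool grad) {
  if (with_bias && with_gelu) return grad ? HIPBLASLT_EPILOGUE_DGELU_BGRAD : HIPBLASLT_EPILOGUE_GELU_AUX_BIAS;
  if (with_bias)              return grad ? HIPBLASLT_EPILOGUE_BGRADB      : HIPBLASLT_EPILOGUE_BIAS;
  if (with_gelu)              return grad ? HIPBLASLT_EPILOGUE_DGELU       : HIPBLASLT_EPILOGUE_GELU_AUX;
  return HIPBLASLT_EPILOGUE_DEFAULT;
}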
bestTime = algoTime;
GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type, }
use_fp8 ? bias_type : (hipDataType)-1, m, n, k, lda, ldb, ldd, transa, }
transb, epilogue);
GemmAlgoCache::Algo cached_algo; NVTE_CHECK_CUDA(hipStreamDestroy(profilingStream));
if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value()) { if (bestAlgo >= 0) {
int firstAlgo = getIntEnv("TE_HIPBLASLT_ALGO_SELECTION", 0, 0); if (logTuning)
int tuneLoopCount = getIntEnv("TE_HIPBLASLT_TUNING_RUN_COUNT", 0, 0); std::cout << "[INFO] Select hipBLASLt algo " << bestAlgo << " with time "
int algoTuneCount = 1; << std::chrono::duration_cast<std::chrono::nanoseconds>(bestTime).count() /
std::vector<hipblasLtMatmulHeuristicResult_t> algoArr; tuneLoopCount
bool logTuning = getIntEnv("TE_HIPBLASLT_LOG_TUNING", 0, 0) != 0; << " ns" << std::endl;
}
if (tuneLoopCount) { } else if (firstAlgo < algoTuneCount) {
/* HIPBLASLT may return hundreds of algos for some configs bestAlgo = firstAlgo;
* Limit amount by default. User may override with env }
*/
static const int defaultAlgoCount = 16; if (bestAlgo < 0) {
algoTuneCount = getIntEnv("TE_HIPBLASLT_TUNING_ALGO_COUNT", defaultAlgoCount, 1); NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
} NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
algoTuneCount += firstAlgo; NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
int algoTotalCount = NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
cached_algo.hasId() ? std::max(algoTuneCount, (cached_algo.index + 1)) : algoTuneCount; throw std::runtime_error("Unable to find any suitable algorithms");
algoArr.resize(algoTotalCount); }
cached_algo.algo = algoArr[bestAlgo].algo;
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceCreate(&preference)); cached_algo.index = bestAlgo;
NVTE_CHECK_HIPBLASLT( cached_algo.algoId = cached_algo.getAlgoId(algoArr[bestAlgo].algo);
hipblasLtMatmulPreferenceSetAttribute(preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, cached_algo.ws_size_min = algoArr[bestAlgo].workspaceSize;
&workspaceSize, sizeof(workspaceSize))); cached_algo.ws_size_max = workspaceSize;
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Ddesc, if (logTuning)
Ddesc, preference, algoTotalCount, std::cout << "[INFO] Use hipBLASLt algo [" << bestAlgo << "] " << cached_algo.algoId
algoArr.data(), &algoTotalCount)); << std::endl;
algoArr.resize(algoTotalCount);
algoCache.store(gemm_cfg, cached_algo);
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceDestroy(preference)); }
}
//If cached algo exists in persistent storage we just need to find matching hipblasLtMatmulAlgo_t
if (cached_algo.hasId()) { // D = alpha * (A * B) + beta * C
int idx = (cached_algo.index < algoTotalCount) ? cached_algo.index : 0; NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
for (int i = 0; i < algoTotalCount; i++) { static_cast<const void*>(&one), /* alpha */
const auto& algo = algoArr[idx]; A, /* A */
if (algo.state == HIPBLAS_STATUS_SUCCESS) { Adesc, B, /* B */
if (cached_algo.algoId == cached_algo.getAlgoId(algo.algo)) { Bdesc, static_cast<const void*>(&beta), /* beta */
cached_algo.algo = algo.algo; D, /* C */
if (algo.workspaceSize != cached_algo.ws_size_min || idx != cached_algo.index) { Ddesc, D, /* D */
cached_algo.ws_size_min = algo.workspaceSize; Ddesc, &cached_algo.algo.value(), /* algo */
cached_algo.index = idx; workspace, /* workspace */
algoCache.store(gemm_cfg, cached_algo); workspaceSize, stream)); /* stream */
}
break; NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
} NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
} NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
idx = (idx + 1) % algoTotalCount; NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
} }
if (logTuning && !cached_algo.algo.has_value()) {
std::cout << "[WARNING] Cannot find cached algoId " << cached_algo.algoId struct HipBlasLtUserArgsDeleter {
<< " in hipBLASLt results" << std::endl; void operator()(hipblaslt_ext::UserArguments* ptr) const noexcept {
} hipFree(ptr);
} }
};
//No suitable entry in autotune cache or could not find matched algo in hipBLASLt results
if (!cached_algo.algo.has_value()) { using HipBlasLtUserArgsPtr = std::unique_ptr<hipblaslt_ext::UserArguments, HipBlasLtUserArgsDeleter>;
int bestAlgo = -1;
algoTuneCount = std::min(algoTuneCount, algoTotalCount); inline HipBlasLtUserArgsPtr make_hipblaslt_user_args_ptr(size_t size, bool host) {
if (tuneLoopCount > 0) { hipblaslt_ext::UserArguments* raw_ptr = nullptr;
if (logTuning) if (host) {
std::cout << "[INFO] Perform hipBLASLt algo selection on GPU" << device_id NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
<< " in range [" << firstAlgo << "-" << (algoTuneCount - 1) << "] with " } else {
<< tuneLoopCount << " loops " << std::endl; NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
}
NVTE_CHECK_CUDA(hipStreamSynchronize(stream)); return HipBlasLtUserArgsPtr(raw_ptr);
hipStream_t profilingStream; }
NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
using tuning_clock = std::chrono::steady_clock; inline hipblaslt_ext::UserArguments* get_hipblaslt_user_args(size_t size, bool host) {
tuning_clock::now(); // the first call takes a little longer, so do it outside the loop thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> host_userargs_cache;
tuning_clock::duration bestTime = tuning_clock::duration::max(); thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> device_userargs_cache;
std::unordered_map<size_t, HipBlasLtUserArgsPtr>& user_args_cache = host ? host_userargs_cache : device_userargs_cache;
for (int algo = firstAlgo; algo < algoTuneCount; algo++) { auto size_it = user_args_cache.find(size);
if (algoArr[algo].state != HIPBLAS_STATUS_SUCCESS) { if (size_it != user_args_cache.end()) {
continue; return size_it->second.get();
} }
// Warm-up call else
NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, {
static_cast<const void*>(&one), /* alpha */ HipBlasLtUserArgsPtr user_args = make_hipblaslt_user_args_ptr(size, host);
A, /* A */ hipblaslt_ext::UserArguments* raw_ptr = user_args.get();
Adesc, B, /* B */ user_args_cache[size] = std::move(user_args);
Bdesc, static_cast<const void*>(&beta), /* beta */ return raw_ptr;
D, /* C */ }
Ddesc, D, /* D */ }
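// Minimal sketch of the size-keyed, thread_local buffer cache used above, with
// a plain POD standing in for hipblaslt_ext::UserArguments (ExampleArgs and
// example_get_args are hypothetical names). Note that memory obtained with
// hipHostMalloc is normally released with hipHostFree, so this sketch keeps the
// host/device flag in the deleter; error handling is omitted.
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <hip/hip_runtime.h>

namespace example {
struct ExampleArgs { int dummy; };  // hypothetical stand-in payload

struct ExampleArgsDeleter {
  bool host = false;
  void operator()(ExampleArgs* p) const noexcept { host ? hipHostFree(p) : hipFree(p); }
};
using ExampleArgsPtr = std::unique_ptr<ExampleArgs, ExampleArgsDeleter>;

inline ExampleArgs* example_get_args(size_t count, bool host) {
  thread_local std::unordered_map<size_t, ExampleArgsPtr> cache[2];
  auto& m = cache[host ? 1 : 0];
  auto it = m.find(count);
  if (it != m.end()) return it->second.get();  // reuse the buffer for this size
  ExampleArgs* raw = nullptr;
  if (host) {
    (void)hipHostMalloc(&raw, count * sizeof(ExampleArgs));  // pinned host buffer
  } else {
    (void)hipMalloc(&raw, count * sizeof(ExampleArgs));      // device buffer
  }
  m.emplace(count, ExampleArgsPtr(raw, ExampleArgsDeleter{host}));
  return raw;
}
}  // namespace example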
Ddesc, &algoArr[algo].algo, /* algo */
workspace, /* workspace */
workspaceSize, profilingStream)); /* stream */ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream)); std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b,
//Profiling loop hipblasOperation_t transa, hipblasOperation_t transb, void* workspace,
tuning_clock::time_point startTime = tuning_clock::now(); size_t workspaceSize, bool accumulate, bool use_split_accumulator,
for (int loop = 0; loop < tuneLoopCount; loop++) { // Check that compute_stream_offset is valid.
NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, // Check compute_stream_offset valid.
static_cast<const void*>(&one), /* alpha */ NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
A, /* A */
Adesc, B, /* B */ hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
Bdesc, static_cast<const void*>(&beta), /* beta */ hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);
D, /* C */
Ddesc, D, /* D */ // hipblaslt_ext::UserArguments* userArgs;
Ddesc, &algoArr[algo].algo, /* algo */ // NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
workspace, /* workspace */
workspaceSize, profilingStream)); /* stream */ hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
}
NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream)); const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype);
tuning_clock::duration algoTime = tuning_clock::now() - startTime; const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype);
if (algoTime < bestTime) { const hipDataType D_type = get_hipblaslt_dtype(outputD[0]->data.dtype);
bestAlgo = algo;
bestTime = algoTime; hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
}
} float one = 1.0;
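// Generic sketch of the measurement pattern used in the tuning loop above:
// warm up once, then time `iters` submissions on a dedicated non-blocking
// stream and synchronize before reading the clock. `launch` is any callable
// that enqueues work on the given stream; names here are illustrative and
// error handling is omitted.
#include <chrono>
#include <hip/hip_runtime.h>

template <typename Launch>
inline long long example_time_avg_ns(Launch&& launch, int iters) {
  hipStream_t s = nullptr;
  (void)hipStreamCreateWithFlags(&s, hipStreamNonBlocking);
  launch(s);                        // warm-up submission, excluded from timing
  (void)hipStreamSynchronize(s);
  const auto t0 = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i) launch(s);
  (void)hipStreamSynchronize(s);    // include completion of all timed launches
  const auto dt = std::chrono::steady_clock::now() - t0;
  (void)hipStreamDestroy(s);
  return std::chrono::duration_cast<std::chrono::nanoseconds>(dt).count() / (iters > 0 ? iters : 1);
}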
float zero = 0.0;
NVTE_CHECK_CUDA(hipStreamDestroy(profilingStream)); float beta = (accumulate) ? one : zero;
if (bestAlgo >= 0) { int int_one = 1;
if (logTuning) int int_zero = 0;
std::cout << "[INFO] Select hipBLASLt algo " << bestAlgo << " with time " int int_beta = int_zero;
<< std::chrono::duration_cast<std::chrono::nanoseconds>(bestTime).count() / bool use_int8 = false;
tuneLoopCount
<< " ns" << std::endl; if ((A_type == HIP_R_8I) && (B_type == HIP_R_8I) && (D_type == HIP_R_32I)) {
} NVTE_CHECK(!accumulate, "Int8 gemm does not support accumulate.");
} else if (firstAlgo < algoTuneCount) { use_int8 = true;
bestAlgo = firstAlgo; computeType = HIPBLAS_COMPUTE_32I;
} }
if (bestAlgo < 0) { hipblaslt_ext::GemmPreference gemmPref;
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc)); gemmPref.setMaxWorkspaceBytes(workspaceSize);
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc)); hipblaslt_ext::GroupedGemm groupedgemm(handle, transa, transb, A_type, B_type, D_type, D_type,
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc)); computeType);
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
throw std::runtime_error("Unable to find any suitable algorithms"); std::vector<hipblaslt_ext::GemmEpilogue> epilogue{
} hipblaslt_ext::
cached_algo.algo = algoArr[bestAlgo].algo; GemmEpilogue()}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
cached_algo.index = bestAlgo; std::vector<hipblaslt_ext::GemmInputs> inputs(m.size());
cached_algo.algoId = cached_algo.getAlgoId(algoArr[bestAlgo].algo); for (int i = 0; i < m.size(); i++) {
cached_algo.ws_size_min = algoArr[bestAlgo].workspaceSize; inputs[i].a = inputA[i]->data.dptr;
cached_algo.ws_size_max = workspaceSize; inputs[i].b = inputB[i]->data.dptr;
inputs[i].c = outputD[i]->data.dptr;
if (logTuning) inputs[i].d = outputD[i]->data.dptr;
std::cout << "[INFO] Use hipBLASLt algo [" << bestAlgo << "] " << cached_algo.algoId inputs[i].alpha = use_int8 ? static_cast<void*>(&int_one) : static_cast<void*>(&one);
<< std::endl; inputs[i].beta = use_int8 ? static_cast<void*>(&int_beta) : static_cast<void*>(&beta);
}
algoCache.store(gemm_cfg, cached_algo); // hipblaslt_ext::GemmEpilogue supports broadcasting
} groupedgemm.setProblem(m, n, k, b, epilogue, inputs);
}
const int request_solutions = 1;
// D = alpha * (A * B) + beta * C std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, NVTE_CHECK_HIPBLASLT(groupedgemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
static_cast<const void*>(&one), /* alpha */
A, /* A */ if (heuristicResult.empty()) {
Adesc, B, /* B */ std::cerr << "No valid solution found!" << std::endl;
Bdesc, static_cast<const void*>(&beta), /* beta */ return;
D, /* C */ }
Ddesc, D, /* D */
Ddesc, &cached_algo.algo.value(), /* algo */ // Make sure to initialize every time the algo changes
workspace, /* workspace */ NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
workspaceSize, stream)); /* stream */
// Get the default values from the groupedgemm object
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc)); groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc)); // Copy them to device memory
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc)); // hipblaslt_ext::UserArguments* d_userArgs;
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc)); // NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
} NVTE_CHECK_CUDA(hipMemcpy(d_userArgs, userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments),
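// Simplified sketch of the algo-cache lookup consulted above: entries are
// keyed by the problem configuration, and a cached entry is only reused when
// the currently available workspace falls inside the [ws_size_min, ws_size_max]
// window it was tuned for. The string key and the Example* names are
// illustrative stand-ins for this file's GemmAlgoCache::Key / Algo.
#include <cstdint>
#include <map>
#include <string>

struct ExampleCachedAlgo {
  uint64_t algoId = 0;
  int index = -1;
  size_t ws_size_min = 0;
  size_t ws_size_max = 0;
};

class ExampleAlgoCache {
 public:
  bool find(const std::string& cfg, size_t workspace, ExampleCachedAlgo* out) const {
    auto range = d_.equal_range(cfg);
    for (auto it = range.first; it != range.second; ++it) {
      const ExampleCachedAlgo& a = it->second;
      if (a.ws_size_min <= workspace && workspace <= a.ws_size_max) { *out = a; return true; }
    }
    return false;  // caller falls back to heuristics / tuning
  }
  void store(const std::string& cfg, const ExampleCachedAlgo& a) { d_.emplace(cfg, a); }

 private:
  std::multimap<std::string, ExampleCachedAlgo> d_;
};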
hipMemcpyHostToDevice));
class userArgsManager { NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
public: // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
userArgsManager() {} // NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
~userArgsManager() { // NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
// Release all userArgs when the manager is destroyed // NVTE_CHECK_CUDA(hipFree(userArgs));
for (auto& device_pair : userArgs_map_) { }
hipFree(device_pair.second); // Only one userArgs per device
} #endif //USE_HIPBLASLT
}
#ifdef USE_ROCBLAS // Use rocblas + kernel, no fusion
// Get a userArgs for the given device (creates if necessary)
hipblaslt_ext::UserArguments* get(int device_id, size_t size) { inline void CreateRocblasHandle(rocblas_handle* handle) {
std::lock_guard<std::mutex> lock(mutex_); NVTE_CHECK_ROCBLAS(rocblas_create_handle(handle));
}
// Check if the userArgs for this device exists
auto device_it = userArgs_map_.find(device_id); using rocblasHandleManager = detail::HandleManager<rocblas_handle, CreateRocblasHandle>;
if (device_it != userArgs_map_.end()) { void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
return device_it->second; const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
} int ldb, int ldd, rocblas_operation transa, rocblas_operation transb, bool grad,
void* workspace, size_t workspaceSize, bool accumulate,
// Create a new userArgs for this device if it doesn't exist bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
hipblaslt_ext::UserArguments* userArgs; bool gemm_producer, const Tensor* inputCounter, hipStream_t stream) {
NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, size * sizeof(hipblaslt_ext::UserArguments))); void* A = inputA->data.dptr;
void* A_scale_inverse = inputA->scale_inv.dptr;
// Store the userArgs in the map for this device void* B = inputB->data.dptr;
userArgs_map_[device_id] = userArgs; void* B_scale_inverse = inputB->scale_inv.dptr;
return userArgs; void* C = outputD->data.dptr;
} void* D = outputD->data.dptr;
void* D_scale = outputD->scale.dptr;
private: void* D_amax = outputD->amax.dptr;
std::unordered_map<int, hipblaslt_ext::UserArguments*> void* bias_ptr = inputBias->data.dptr;
userArgs_map_; // Map from device_id to hipblasHandle const bool bias = bias_ptr != nullptr;
std::mutex mutex_; void* pre_gelu_out = outputPreGelu->data.dptr;
}; const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
class d_userArgsManager { const rocblas_datatype A_type = get_rocblas_dtype(inputA->data.dtype);
public: const rocblas_datatype B_type = get_rocblas_dtype(inputB->data.dtype);
d_userArgsManager() {} const rocblas_datatype D_type = get_rocblas_dtype(outputD->data.dtype);
const rocblas_datatype bias_type = get_rocblas_dtype(inputBias->data.dtype);
~d_userArgsManager() { const rocblas_datatype gelu_type = get_rocblas_dtype(outputPreGelu->data.dtype);
// Release all userArgs when the manager is destroyed
for (auto& device_pair : d_userArgs_map_) { // check consistency of arguments:
hipFree(device_pair.second); // Only one userArgs per device // if fp8 is desired, context cannot be null
} // fp8 + gelu fusion + fp8 aux is unavailable right now.
} if (use_fp8 && gelu) {
NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
// Get a userArgs for the given device (creates if necessary) "fp8 Aux output for gemm + gelu fusion not supported!");
hipblaslt_ext::UserArguments* get(int device_id, size_t size) { }
std::lock_guard<std::mutex> lock(mutex_); if (is_fp8_dtype(outputD->data.dtype)) {
NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
// Check if the userArgs for this device exists }
auto device_it = d_userArgs_map_.find(device_id); // fp8 + grad unavailable in upstream
if (device_it != d_userArgs_map_.end()) { NVTE_CHECK(!(use_fp8 && grad), "fp8 + grad not supported!");
return device_it->second;
} float one = 1.0;
float zero = 0.0;
// Create a new userArgs for this device if it doesn't exist float beta = (accumulate) ? one : zero;
hipblaslt_ext::UserArguments* d_userArgs;
NVTE_CHECK_CUDA(hipMalloc(&d_userArgs, size * sizeof(hipblaslt_ext::UserArguments))); float alpha = 1.0;
if (use_fp8) {
// Store the userArgs in the map for this device float A_scale_inv, B_scale_inv;
d_userArgs_map_[device_id] = d_userArgs; (void)hipMemcpy(&A_scale_inv, A_scale_inverse, sizeof(float), hipMemcpyDeviceToHost);
return d_userArgs; (void)hipMemcpy(&B_scale_inv, B_scale_inverse, sizeof(float), hipMemcpyDeviceToHost);
} alpha = A_scale_inv * B_scale_inv;
}
private:
std::unordered_map<int, hipblaslt_ext::UserArguments*> rocblas_handle handle = rocblasHandleManager::Instance().GetHandle();
d_userArgs_map_; // Map from device_id to hipblasHandle NVTE_CHECK_ROCBLAS(rocblas_set_stream(handle, stream));
std::mutex mutex_;
}; // extract the stream order alloc env
bool stream_order_alloc = false;
// Define a static userArgs manager if (const char* env_p = std::getenv("ROCBLAS_STREAM_ORDER_ALLOC")) {
static userArgsManager UAManager; if (env_p == nullptr || std::string(env_p) == "1") stream_order_alloc = true;
static d_userArgsManager d_UAManager; }
void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB, int64_t ld_gelumat = (int64_t)ldd;
std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b, NVTE_CHECK((A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r &&
hipblasOperation_t transa, hipblasOperation_t transb, void* workspace, D_type == rocblas_datatype_f16_r) ||
size_t workspaceSize, bool accumulate, bool use_split_accumulator, (A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r &&
int math_sm_count, hipStream_t stream, int compute_stream_offset = 0) { D_type == rocblas_datatype_f32_r) ||
// Check that compute_stream_offset is valid. (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r &&
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams); D_type == rocblas_datatype_bf16_r) ||
(A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r &&
int device_id; D_type == rocblas_datatype_f32_r) ||
hipGetDevice(&device_id); (A_type == rocblas_datatype_f32_r && B_type == rocblas_datatype_f32_r &&
hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size()); D_type == rocblas_datatype_f32_r) ||
hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size()); (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
D_type == rocblas_datatype_f32_r) ||
// hipblaslt_ext::UserArguments* userArgs; (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments))); D_type == rocblas_datatype_f16_r) ||
(A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
hipblasLtHandle_t handle = nullptr; D_type == rocblas_datatype_bf16_r) ||
if (compute_stream_offset != -1) { (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
// Init hipblaslt handles (once, globally) D_type == rocblas_datatype_f8_r) ||
static std::once_flag init_flag; (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
static hipblasLtHandle_t hipblaslt_handles[compute_num_streams]; D_type == rocblas_datatype_bf8_r) ||
std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles); (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
D_type == rocblas_datatype_f32_r) ||
handle = hipblaslt_handles[compute_stream_offset]; (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
} D_type == rocblas_datatype_f16_r) ||
(A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype); D_type == rocblas_datatype_bf16_r) ||
const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype); (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
const hipDataType D_type = get_hipblaslt_dtype(outputD[0]->data.dtype); D_type == rocblas_datatype_f8_r) ||
(A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; D_type == rocblas_datatype_bf8_r) ||
(A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
float one = 1.0; D_type == rocblas_datatype_f32_r) ||
float zero = 0.0; (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
float beta = (accumulate) ? one : zero; D_type == rocblas_datatype_f16_r) ||
int int_one = 1; (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
int int_zero = 0; D_type == rocblas_datatype_bf16_r) ||
int int_beta = int_zero; (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
bool use_int8 = false; D_type == rocblas_datatype_f8_r) ||
(A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
if ((A_type == HIP_R_8I) && (B_type == HIP_R_8I) && (D_type == HIP_R_32I)) { D_type == rocblas_datatype_bf8_r),
NVTE_CHECK(!accumulate, "Int8 gemm not support accumulate."); "Only the following combinations of data types are enabled now!\n\
use_int8 = true; 1. input: fp32, output: fp32.\n\
computeType = HIPBLAS_COMPUTE_32I; 2. input: fp16, output: fp16.\n\
} 3. input: bf16, output: bf16.\n\
4. input: fp8/bf8, output: fp8/bf8, fp16/bf16, fp32");
hipblaslt_ext::GemmPreference gemmPref;
gemmPref.setMaxWorkspaceBytes(workspaceSize); //If D is not fp32, then we need a temp buffer for GEMM result before applying epilogues. Otherwise, we can apply epilogues in-place.
hipblaslt_ext::GroupedGemm groupedgemm(handle, transa, transb, A_type, B_type, D_type, D_type, // with bias or gelu, allocate fp32 D_temp if the output is not fp32
computeType); // with input fp8/bf8 (use_fp8) and bf16 output, need a fp32 D_temp, as rocblas does not support this case (fp8/bf8 input fp16/fp32 output is supported)
// with use_fp8 true and fp8/bf8 output, need fp32 D_temp to support amax and scale operation
std::vector<hipblaslt_ext::GemmEpilogue> epilogue{ void* D_temp;
hipblaslt_ext:: if (((bias || gelu) && (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r)) ||
GemmEpilogue()}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only) (use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
std::vector<hipblaslt_ext::GemmInputs> inputs(m.size()); D_type == rocblas_datatype_bf8_r))) {
for (int i = 0; i < m.size(); i++) { if (!stream_order_alloc) {
inputs[i].a = inputA[i]->data.dptr; NVTE_CHECK_CUDA(hipMalloc(&D_temp, sizeof(float) * m * n));
inputs[i].b = inputB[i]->data.dptr; } else {
inputs[i].c = outputD[i]->data.dptr; NVTE_CHECK_CUDA(hipMallocAsync(&D_temp, sizeof(float) * m * n, stream));
inputs[i].d = outputD[i]->data.dptr; }
inputs[i].alpha = use_int8 ? static_cast<void*>(&int_one) : static_cast<void*>(&one); } else {
inputs[i].beta = use_int8 ? static_cast<void*>(&int_beta) : static_cast<void*>(&beta); D_temp = D;
} }
// hipblaslt_ext::GemmEpilogue supports broadcasting
groupedgemm.setProblem(m, n, k, b, epilogue, inputs); // When Ti=To=fp16 and there is no bias or gelu, D_temp points to D and we would like it to be fp16
rocblas_datatype D_temp_type = rocblas_datatype_f32_r;
const int request_solutions = 1; if (!(bias || gelu) && (A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r &&
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult; D_type == rocblas_datatype_f16_r)) {
NVTE_CHECK_HIPBLASLT(groupedgemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult)); D_temp_type = rocblas_datatype_f16_r;
}
if (heuristicResult.empty()) { // When Ti=To=bf16 and there is no bias or gelu, D_temp points to D and we would like it to be bf16
std::cerr << "No valid solution found!" << std::endl; if (!(bias || gelu) && (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r &&
return; D_type == rocblas_datatype_bf16_r)) {
} D_temp_type = rocblas_datatype_bf16_r;
}
// Make sure to initialize every time the algo changes // When Ti is fp8 or bf8, To=fp16, and there is no bias or gelu, D_temp points to D and we would like it to be fp16, as rocblas supports this case.
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace)); if ((!(bias || gelu)) && (use_fp8 && D_type == rocblas_datatype_f16_r)) {
D_temp_type = rocblas_datatype_f16_r;
// Get the default values from the groupedgemm object }
groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
// Copy them to device memory if (accumulate && (D_temp != D || D_temp_type != D_type)) {
// hipblaslt_ext::UserArguments* d_userArgs; DType output_dtype = get_transformer_engine_dtype(D_type);
// NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream)); TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
NVTE_CHECK_CUDA(hipMemcpy(d_userArgs, userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), output_dtype, OType,
hipMemcpyHostToDevice)); //D_temp allocated only with fp32
detail::identity_kernelLauncher<OType, float>(
NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream)); reinterpret_cast<const OType*>(D), reinterpret_cast<float*>(D_temp), m * n, stream););
// NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream)); }
// NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
// D = alpha * (A * B) + beta * C
// NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream)); if (use_fp8) {
// NVTE_CHECK_CUDA(hipFree(userArgs)); rocblas_computetype computeType = rocblas_compute_type_f32;
} NVTE_CHECK_ROCBLAS(rocblas_gemm_ex3(handle, transa, transb, m, n, k, &alpha, A, A_type, lda, B,
B_type, ldb, &beta, D_temp, D_temp_type, ldd, D_temp,
#endif //USE_HIPBLASLT D_temp_type, ldd, computeType,
rocblas_gemm_algo::rocblas_gemm_algo_standard, 0, 0));
#ifdef USE_ROCBLAS // Use rocblas + kernel, no fusion } else {
rocblas_datatype computeType = rocblas_datatype_f32_r;
inline void CreateRocblasHandle(rocblas_handle* handle) { uint32_t flags = rocblas_gemm_flags_none;
NVTE_CHECK_ROCBLAS(rocblas_create_handle(handle)); if ((A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r) && grad) {
} flags = rocblas_gemm_flags_fp16_alt_impl;
}
using rocblasHandleManager = detail::HandleManager<rocblas_handle, CreateRocblasHandle>; NVTE_CHECK_ROCBLAS(rocblas_gemm_ex(handle, transa, transb, m, n, k, &alpha, A, A_type, lda, B,
void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD, B_type, ldb, &beta, D_temp, D_temp_type, ldd, D_temp,
const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda, D_temp_type, ldd, computeType,
int ldb, int ldd, rocblas_operation transa, rocblas_operation transb, bool grad, rocblas_gemm_algo::rocblas_gemm_algo_standard, 0, flags));
void* workspace, size_t workspaceSize, bool accumulate, }
bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
bool gemm_producer, const Tensor* inputCounter, hipStream_t stream) { int batch_size, input_dim, output_dim;
void* A = inputA->data.dptr; if (bias && gelu) {
void* A_scale_inverse = inputA->scale_inv.dptr; if (grad) {
void* B = inputB->data.dptr; // epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
void* B_scale_inverse = inputB->scale_inv.dptr; // Apply GELU gradient to D_temp and store in D
void* C = outputD->data.dptr; // Apply bias gradient to D (D is already the result of GELU gradient) and store in bias_ptr;
void* D = outputD->data.dptr; // This case is NN
void* D_scale = outputD->scale.dptr; // D_temp is of shape (m, n) in column major and thus is of shape (n, m) in row major
void* D_amax = outputD->amax.dptr; // The bias vector length is m. So it will be reduced along axis 0 in row major
void* bias_ptr = inputBias->data.dptr; // (TODO): The cublasLt doc is not very clear wrt the bias gradient here.
const bool bias = bias_ptr != nullptr; // It does not explicitly say that it goes through GELU gradient first. We will need to
void* pre_gelu_out = outputPreGelu->data.dptr; // confirm in the future. As of now, my implementation for the bias gradient takes
const bool gelu = pre_gelu_out != nullptr; // the GELU gradient result in lower precision (D). It might be better to take the GELU
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype); // gradient result in fp32 but as it requires some kernel changes I would only do that
const rocblas_datatype A_type = get_rocblas_dtype(inputA->data.dtype); // once we confirm that this is the right form of the epilogue.
const rocblas_datatype B_type = get_rocblas_dtype(inputB->data.dtype); // This is for linear1 -> gelu -> linear2
const rocblas_datatype D_type = get_rocblas_dtype(outputD->data.dtype); // compute dX = dY * W for linear2
const rocblas_datatype bias_type = get_rocblas_dtype(inputBias->data.dtype); // gemm_ex(A=W, B=dY)
const rocblas_datatype gelu_type = get_rocblas_dtype(outputPreGelu->data.dtype); batch_size = n;
input_dim =
// check consistency of arguments: m; // input dimension of the second linear layer is the output dimension of the first linear layer
// if fp8 is desired, context cannot be null output_dim = k;
// fp8 + gelu fusion + fp8 aux is unavailable right now. DType output_dtype = get_transformer_engine_dtype(D_type);
if (use_fp8 && gelu) { DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype), TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
"fp8 Aux output for gemm + gelu fusion not supported!"); output_dtype, OType,
} TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
if (is_fp8_dtype(outputD->data.dtype)) { gelu_dtype, GType,
NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!"); detail::gelu_backward_kernelLauncher<OType, GType>(
} reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
// fp8 + grad unavailable in upstream reinterpret_cast<const GType*>(pre_gelu_out), batch_size, input_dim, stream);););
NVTE_CHECK(!(use_fp8 && grad), "fp8 + grad not supported!");
void* bias_tmp;
float one = 1.0; if (bias_type != rocblas_datatype_f32_r) {
float zero = 0.0; if (!stream_order_alloc) {
float beta = (accumulate) ? one : zero; NVTE_CHECK_CUDA(hipMalloc(
&bias_tmp,
float alpha = 1.0; sizeof(float) * input_dim)); // The bias gradient is for the first linear layer
if (use_fp8) { } else {
float A_scale_inv, B_scale_inv; NVTE_CHECK_CUDA(hipMallocAsync(&bias_tmp, sizeof(float) * input_dim, stream));
(void)hipMemcpy(&A_scale_inv, A_scale_inverse, sizeof(float), hipMemcpyDeviceToHost); }
(void)hipMemcpy(&B_scale_inv, B_scale_inverse, sizeof(float), hipMemcpyDeviceToHost); } else {
alpha = A_scale_inv * B_scale_inv; bias_tmp = bias_ptr;
} }
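// Sketch of the host-side alpha computed just above for the rocBLAS FP8 path:
// unlike the hipBLASLt path, which hands scale-inverse pointers to the matmul
// descriptor, this path copies both per-tensor scale inverses to the host and
// folds them into a single scalar, so D = (a_scale_inv * b_scale_inv) * (A * B)
// + beta * C. Synchronous copies, as in the surrounding code; the function name
// is illustrative and error handling is omitted.
#include <hip/hip_runtime.h>

inline float example_fp8_alpha(const void* a_scale_inv_dev, const void* b_scale_inv_dev) {
  float a = 1.f, b = 1.f;
  (void)hipMemcpy(&a, a_scale_inv_dev, sizeof(float), hipMemcpyDeviceToHost);
  (void)hipMemcpy(&b, b_scale_inv_dev, sizeof(float), hipMemcpyDeviceToHost);
  return a * b;
}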
rocblas_handle handle = rocblasHandleManager::Instance().GetHandle(); TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
NVTE_CHECK_ROCBLAS(rocblas_set_stream(handle, stream)); output_dtype, OType,
detail::bias_gradient_kernelLauncher<OType>(
// extract the stream order alloc env reinterpret_cast<const OType*>(D), reinterpret_cast<float*>(bias_tmp), batch_size,
bool stream_order_alloc = false; input_dim, stream_order_alloc, stream););
if (const char* env_p = std::getenv("ROCBLAS_STREAM_ORDER_ALLOC")) {
if (env_p == nullptr || std::string(env_p) == "1") stream_order_alloc = true; if (bias_type != rocblas_datatype_f32_r) {
} DType bias_dtype = get_transformer_engine_dtype(bias_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
int64_t ld_gelumat = (int64_t)ldd; bias_dtype, BType,
detail::identity_kernelLauncher<float, BType>(reinterpret_cast<const float*>(bias_tmp),
NVTE_CHECK((A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r && reinterpret_cast<BType*>(bias_ptr),
D_type == rocblas_datatype_f16_r) || input_dim, stream););
(A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r && if (!stream_order_alloc) {
D_type == rocblas_datatype_f32_r) || NVTE_CHECK_CUDA(hipFree(bias_tmp));
(A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r && } else {
D_type == rocblas_datatype_bf16_r) || NVTE_CHECK_CUDA(hipFreeAsync(bias_tmp, stream));
(A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r && }
D_type == rocblas_datatype_f32_r) || }
(A_type == rocblas_datatype_f32_r && B_type == rocblas_datatype_f32_r &&
D_type == rocblas_datatype_f32_r) || } else {
(A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r && // epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
D_type == rocblas_datatype_f32_r) || // Add bias_ptr to D_temp and store in pre_gelu_out, and apply GELU to the pre_gelu_output and then store in D
(A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r && // D_temp is of shape (m, n) in column major and thus is of shape (n, m) in row major
D_type == rocblas_datatype_f16_r) || // gemm_ex(A=W, B=X, transA=T)
(A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r && batch_size = n;
D_type == rocblas_datatype_bf16_r) || input_dim = k;
(A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r && output_dim = m;
D_type == rocblas_datatype_f8_r) || DType output_dtype = get_transformer_engine_dtype(D_type);
(A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r && DType bias_dtype = get_transformer_engine_dtype(bias_type);
D_type == rocblas_datatype_bf8_r) || DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
(A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r && TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
D_type == rocblas_datatype_f32_r) || output_dtype, OType,
(A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r && TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
D_type == rocblas_datatype_f16_r) || gelu_dtype, GType,
(A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r && TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
D_type == rocblas_datatype_bf16_r) || bias_dtype, BType,
(A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r && detail::add_bias_gelu_kernelLauncher<OType, GType, BType>(
D_type == rocblas_datatype_f8_r) || reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
(A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r && reinterpret_cast<GType*>(pre_gelu_out),
D_type == rocblas_datatype_bf8_r) || reinterpret_cast<const BType*>(bias_ptr), reinterpret_cast<float*>(D_amax),
(A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r && reinterpret_cast<const float*>(D_scale), batch_size, output_dim,
D_type == rocblas_datatype_f32_r) || stream););););
(A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r && }
D_type == rocblas_datatype_f16_r) || } else if (bias) {
(A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r && if (grad) {
D_type == rocblas_datatype_bf16_r) || // grad output is always input B
(A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r && // epilogue = CUBLASLT_EPILOGUE_BGRADB;
D_type == rocblas_datatype_f8_r) || // Apply bias gradient to matrix B and store in bias_ptr, reduce along the k dimension, output bias length is n
(A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r && // As B is transposed, is of shape (n, k) in column major, and is of shape (k, n) in row major.
D_type == rocblas_datatype_bf8_r), // bias gradient vector length is n. So it will be reduced along axis 0 in row major.
"Only the following combinations of data types are enabled now!\n\ // The backward pass calculate the bias gradient along with dW = dY^T * X
1. input: fp32, output: fp32.\n\ // gemm_ex(A=X, B = dY, transB=T)
2. input: fp16, output: fp16.\n\ batch_size = k;
3. input: bf16, output: bf16.\n\ input_dim = m;
4. input: fp8/bf8, output: fp8/bf8, fp16/bf16, fp32"); output_dim = n;
void* bias_tmp;
//If D is not fp32, then we need a temp buffer for GEMM result before applying epilogues. Otherwise, we can apply epilogues in-place. if (bias_type != rocblas_datatype_f32_r) {
// with bias or gelu, allocate fp32 D_temp if the output is not fp32 if (!stream_order_alloc) {
// with input fp8/bf8 (use_fp8) and bf16 output, need a fp32 D_temp, as rocblas does not support this case (fp8/bf8 input fp16/fp32 output is supported) NVTE_CHECK_CUDA(hipMalloc(&bias_tmp, sizeof(float) * output_dim));
// with use_fp8 true and fp8/bf8 output, need fp32 D_temp to support amax and scale operation } else {
void* D_temp; NVTE_CHECK_CUDA(hipMallocAsync(&bias_tmp, sizeof(float) * output_dim, stream));
if (((bias || gelu) && (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r)) || }
(use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r || } else {
D_type == rocblas_datatype_bf8_r))) { bias_tmp = bias_ptr;
if (!stream_order_alloc) { }
NVTE_CHECK_CUDA(hipMalloc(&D_temp, sizeof(float) * m * n));
} else { DType input_dtype = get_transformer_engine_dtype(B_type);
NVTE_CHECK_CUDA(hipMallocAsync(&D_temp, sizeof(float) * m * n, stream)); DType output_dtype = get_transformer_engine_dtype(D_type);
} DType bias_dtype = get_transformer_engine_dtype(bias_type);
} else { TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
D_temp = D; input_dtype, IType,
} detail::bias_gradient_kernelLauncher<IType>(
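// CPU sketch of the staging pattern decided above: when the requested output
// type is not fp32 but an epilogue or a quantized output is needed, the GEMM
// writes into an fp32 D_temp and a follow-up pass applies the epilogue and
// casts into the real output type. Shown here for the simplest epilogue only
// (broadcast bias add + cast); the GELU variants follow the same structure via
// this file's own kernels. OutT is any arithmetic type standing in for the
// fp16/bf16 wrappers; the name is illustrative.
#include <cstddef>

template <typename OutT>
inline void example_bias_add_cast(const float* d_temp, const float* bias, OutT* d_out,
                                  size_t rows, size_t cols) {
  for (size_t r = 0; r < rows; ++r) {
    for (size_t c = 0; c < cols; ++c) {
      // bias has length `cols` and is broadcast across rows in this sketch;
      // the row/column-major mapping of the real kernels follows the comments above.
      d_out[r * cols + c] = static_cast<OutT>(d_temp[r * cols + c] + bias[c]);
    }
  }
}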
reinterpret_cast<const IType*>(B), reinterpret_cast<float*>(bias_tmp), batch_size,
// When Ti=To=fp16 and there is no bias or gelu, D_temp points to D and we would like it to be fp16 output_dim, stream_order_alloc, stream););
rocblas_datatype D_temp_type = rocblas_datatype_f32_r; if (bias_type != rocblas_datatype_f32_r) {
if (!(bias || gelu) && (A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r && TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
D_type == rocblas_datatype_f16_r)) { bias_dtype, BType,
D_temp_type = rocblas_datatype_f16_r; detail::identity_kernelLauncher<float, BType>(reinterpret_cast<const float*>(bias_tmp),
} reinterpret_cast<BType*>(bias_ptr),
// When Ti=To=bf16 and there is no bias or gelu, D_temp points to D and we would like it to be bf16 output_dim, stream););
if (!(bias || gelu) && (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r && if (!stream_order_alloc) {
D_type == rocblas_datatype_bf16_r)) { NVTE_CHECK_CUDA(hipFree(bias_tmp));
D_temp_type = rocblas_datatype_bf16_r; } else {
} NVTE_CHECK_CUDA(hipFreeAsync(bias_tmp, stream));
// When Ti is fp8 or bf8, To=fp16, and there is no bias or gelu, D_temp points to D and we would like it to be fp16, as rocblas supports this case. }
if ((!(bias || gelu)) && (use_fp8 && D_type == rocblas_datatype_f16_r)) { }
D_temp_type = rocblas_datatype_f16_r; if (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r) {
} TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output_dtype, OType,
if (accumulate && (D_temp != D || D_temp_type != D_type)) { detail::identity_kernelLauncher<float, OType>(reinterpret_cast<const float*>(D_temp),
DType output_dtype = get_transformer_engine_dtype(D_type); reinterpret_cast<OType*>(D),
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( input_dim * output_dim, stream););
output_dtype, OType, }
//D_temp allocated only with fp32 } else {
detail::identity_kernelLauncher<OType, float>( // epilogue = CUBLASLT_EPILOGUE_BIAS;
reinterpret_cast<const OType*>(D), reinterpret_cast<float*>(D_temp), m * n, stream);); // Broadcast bias and add it to D_temp and store in D. The bias vector length is m
} // D_temp is of shape (m, n) in column major and thus is of shape (n, m) in row major
// gemm_ex(A=W, B=X, transA=T)
// D = alpha * (A * B) + beta * C batch_size = n;
if (use_fp8) { input_dim = k;
rocblas_computetype computeType = rocblas_compute_type_f32; output_dim = m;
NVTE_CHECK_ROCBLAS(rocblas_gemm_ex3(handle, transa, transb, m, n, k, &alpha, A, A_type, lda, B, DType output_dtype = get_transformer_engine_dtype(D_type);
B_type, ldb, &beta, D_temp, D_temp_type, ldd, D_temp, DType bias_dtype = get_transformer_engine_dtype(bias_type);
D_temp_type, ldd, computeType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
rocblas_gemm_algo::rocblas_gemm_algo_standard, 0, 0)); output_dtype, OType,
} else { TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
rocblas_datatype computeType = rocblas_datatype_f32_r; bias_dtype, BType,
uint32_t flags = rocblas_gemm_flags_none; detail::add_bias_kernelLauncher<OType, BType>(
if ((A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r) && grad) { reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
flags = rocblas_gemm_flags_fp16_alt_impl; reinterpret_cast<const BType*>(bias_ptr), reinterpret_cast<float*>(D_amax),
} reinterpret_cast<const float*>(D_scale), batch_size, output_dim, stream);););
NVTE_CHECK_ROCBLAS(rocblas_gemm_ex(handle, transa, transb, m, n, k, &alpha, A, A_type, lda, B, }
B_type, ldb, &beta, D_temp, D_temp_type, ldd, D_temp, } else if (gelu) {
D_temp_type, ldd, computeType, if (grad) {
rocblas_gemm_algo::rocblas_gemm_algo_standard, 0, flags)); // epilogue = CUBLASLT_EPILOGUE_DGELU;
} // Take input from pre_gelu_out and apply GELU gradients to D_temp and store result in D
// D_temp is of shape is (m, n) in column major and thus is of shape (n, m) in row major
int batch_size, input_dim, output_dim; // gemm_ex(A=W, B=dY)
if (bias && gelu) { batch_size = n;
if (grad) { input_dim = m;
// epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD; output_dim = k;
// Apply GELU gradient to D_temp and store in D DType output_dtype = get_transformer_engine_dtype(D_type);
// Apply bias gradient to D (D is already the result of GELU gradient) and store in bias_ptr; DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
// This case is NN TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
// D_temp is of shape is (m, n) in column major and thus is of shape (n, m) in row major output_dtype, OType,
// The bias vector length is m. So it will be reduced along axis 0 in row major TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
// (TODO): The cublasLt doc is not very clear wrt the bias gradient here. gelu_dtype, GType,
// It does not explicitly say that it goes through GELU gradient first. We will need to detail::gelu_backward_kernelLauncher<OType, GType>(
// confirm in the future. As of now, my implementation for the bias gradient takes reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
// the GELU gradient result in lower precision (D). It might be better to take the GELU reinterpret_cast<const GType*>(pre_gelu_out), batch_size, input_dim, stream);););
// gradient result in fp32 but as it requires some kernel changes I would only do that } else {
// once we confirm that this is the right form of the epilogue. // epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
// This is for linear1 -> gelu -> linear2 // Store (quantized) D_temp in pre_gelu_out, and apply GELU to D_temp then store in D
// compute dX = dY * W for linear2 // D_temp is of shape is (m, n) in column major and thus is of shape (n, m) in row major
// gemm_ex(A=W, B=dY) // gemm_ex(A=W, B=X, transA=T)
batch_size = n; batch_size = n;
input_dim = input_dim = k;
m; // input dimension of the second linear layer is the output dimension of the first linear layer output_dim = m;
output_dim = k;
DType output_dtype = get_transformer_engine_dtype(D_type); DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
DType gelu_dtype = get_transformer_engine_dtype(gelu_type); TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( gelu_dtype, GType,
output_dtype, OType, detail::identity_kernelLauncher<float, GType>(reinterpret_cast<const float*>(D_temp),
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( reinterpret_cast<GType*>(pre_gelu_out),
gelu_dtype, GType, batch_size * output_dim, stream););
detail::gelu_backward_kernelLauncher<OType, GType>( DType output_dtype = get_transformer_engine_dtype(D_type);
reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D), TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
reinterpret_cast<const GType*>(pre_gelu_out), batch_size, input_dim, stream););); output_dtype, OType,
detail::gelu_forward_kernelLauncher<OType>(
void* bias_tmp; reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
if (bias_type != rocblas_datatype_f32_r) { reinterpret_cast<float*>(D_amax), reinterpret_cast<const float*>(D_scale), batch_size,
if (!stream_order_alloc) { output_dim, stream););
NVTE_CHECK_CUDA(hipMalloc( }
&bias_tmp, } else { // No epilogue - !(bias || gelu)
sizeof(float) * input_dim)); // The bias gradient is for the first linear layer if (use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
} else { D_type == rocblas_datatype_bf8_r)) {
NVTE_CHECK_CUDA(hipMallocAsync(&bias_tmp, sizeof(float) * input_dim, stream)); DType output_dtype = get_transformer_engine_dtype(D_type);
} TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
} else { output_dtype, OType,
bias_tmp = bias_ptr; detail::identity_output_kernelLauncher<OType>(
} reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
reinterpret_cast<float*>(D_amax), reinterpret_cast<const float*>(D_scale), m * n,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( stream););
output_dtype, OType, }
detail::bias_gradient_kernelLauncher<OType>( }
reinterpret_cast<const OType*>(D), reinterpret_cast<float*>(bias_tmp), batch_size,
input_dim, stream_order_alloc, stream);); if (((bias || gelu) && (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r)) ||
(use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
if (bias_type != rocblas_datatype_f32_r) { D_type == rocblas_datatype_bf8_r))) {
DType bias_dtype = get_transformer_engine_dtype(bias_type); if (!stream_order_alloc) {
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( NVTE_CHECK_CUDA(hipFree(D_temp));
bias_dtype, BType, } else {
detail::identity_kernelLauncher<float, BType>(reinterpret_cast<const float*>(bias_tmp), NVTE_CHECK_CUDA(hipFreeAsync(D_temp, stream));
reinterpret_cast<BType*>(bias_ptr), }
input_dim, stream);); }
if (!stream_order_alloc) { }
NVTE_CHECK_CUDA(hipFree(bias_tmp));
} else { #endif //USE_ROCBLAS
NVTE_CHECK_CUDA(hipFreeAsync(bias_tmp, stream));
} void cublas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
} const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
int ldb, int ldd, bool transa, bool transb, bool grad, void* workspace,
} else { size_t workspaceSize, bool accumulate, bool use_split_accumulator,
// epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS; int math_sm_count, int m_split, int n_split, bool gemm_producer,
// Add bias_ptr to D_temp and store in pre_gelu_out, and apply GELU to the pre_gelu_output and then store in D const Tensor* inputCounter, hipStream_t stream, bool nvte_use_hipblaslt = 0,
// D_temp is of shape is (m, n) in column major and thus is of shape (n, m) in row major bool nvte_use_rocblas = 0, int compute_stream_offset = -1) {
// gemm_ex(A=W, B=X, transA=T) /*If no backend is specified with env variable use HIPBLASLT unless it is disabled
batch_size = n; If HIPBLASLT backend is enabled and requested, use it despite ROCBLAS status
input_dim = k; Otherwise use ROCBLAS
output_dim = m; */
DType output_dtype = get_transformer_engine_dtype(D_type);
DType bias_dtype = get_transformer_engine_dtype(bias_type); bool use_hipblaslt = (std::getenv("NVTE_USE_HIPBLASLT") != nullptr) || nvte_use_hipblaslt;
DType gelu_dtype = get_transformer_engine_dtype(gelu_type); bool use_rocblas = (std::getenv("NVTE_USE_ROCBLAS") != nullptr) || nvte_use_rocblas;
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output_dtype, OType, #if !defined(USE_HIPBLASLT) && !defined(USE_ROCBLAS)
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( #error GEMM backend is not specified
gelu_dtype, GType, #elif !defined(USE_HIPBLASLT)
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( if (use_hipblaslt) {
bias_dtype, BType, use_hipblaslt = false;
detail::add_bias_gelu_kernelLauncher<OType, GType, BType>( use_rocblas = true;
reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D), std::cout << "[NOTICE] hipBLASLt is not enabled, NVTE_USE_HIPBLASLT env is ignored\n";
reinterpret_cast<GType*>(pre_gelu_out), }
reinterpret_cast<const BType*>(bias_ptr), reinterpret_cast<float*>(D_amax), #elif !defined(USE_ROCBLAS)
reinterpret_cast<const float*>(D_scale), batch_size, output_dim, if (use_rocblas) {
stream);););); use_rocblas = false;
} use_hipblaslt = true;
} else if (bias) { std::cout << "[NOTICE] rocBLAS is not enabled, NVTE_USE_ROCBLAS env is ignored\n";
if (grad) { }
// grad output is always input B #else
// epilogue = CUBLASLT_EPILOGUE_BGRADB; if (use_hipblaslt && use_rocblas) {
// Apply bias gradient to matrix B and store in bias_ptr, reduce along the k dimension, output bias length is n use_rocblas = false;
// As B is transposed, is of shape (n, k) in column major, and is of shape (k, n) in row major. use_hipblaslt = true;
// bias gradient vector length is n. So it will be reduced along axis 0 in row major. // std::cout << "[NOTICE] Two GEMM backend are enabled, hipBLASLt will be used\n";
// The backward pass calculate the bias gradient along with dW = dY^T * X } else if (!use_hipblaslt && !use_rocblas) {
// gemm_ex(A=X, B = dY, transB=T) use_rocblas = false;
batch_size = k; use_hipblaslt = true;
input_dim = m; // std::cout << "[NOTICE] Two GEMM backend are disabled, hipBLASLt will be used\n";
output_dim = n; }
void* bias_tmp; #endif
if (bias_type != rocblas_datatype_f32_r) {
if (!stream_order_alloc) { #ifdef USE_HIPBLASLT
NVTE_CHECK_CUDA(hipMalloc(&bias_tmp, sizeof(float) * output_dim)); if (use_hipblaslt || !use_rocblas) {
} else { // Check compute_stream_offset valid.
NVTE_CHECK_CUDA(hipMallocAsync(&bias_tmp, sizeof(float) * output_dim, stream)); NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
}
} else { hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
bias_tmp = bias_ptr; (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
} grad, workspace, workspaceSize, accumulate, use_split_accumulator, math_sm_count,
m_split, n_split, gemm_producer, inputCounter, stream);
DType input_dtype = get_transformer_engine_dtype(B_type);
DType output_dtype = get_transformer_engine_dtype(D_type); return;
DType bias_dtype = get_transformer_engine_dtype(bias_type); }
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( #endif
input_dtype, IType,
detail::bias_gradient_kernelLauncher<IType>( #ifdef USE_ROCBLAS
reinterpret_cast<const IType*>(B), reinterpret_cast<float*>(bias_tmp), batch_size, if (use_rocblas) {
output_dim, stream_order_alloc, stream);); rocblas_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
if (bias_type != rocblas_datatype_f32_r) { (transa) ? rocblas_operation_transpose : rocblas_operation_none,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( (transb) ? rocblas_operation_transpose : rocblas_operation_none, grad, workspace,
bias_dtype, BType, workspaceSize, accumulate, use_split_accumulator, math_sm_count, m_split, n_split,
detail::identity_kernelLauncher<float, BType>(reinterpret_cast<const float*>(bias_tmp), gemm_producer, inputCounter, stream);
reinterpret_cast<BType*>(bias_ptr), }
output_dim, stream);); #endif
if (!stream_order_alloc) { }
NVTE_CHECK_CUDA(hipFree(bias_tmp));
} else { } //namespace transformer_engine
NVTE_CHECK_CUDA(hipFreeAsync(bias_tmp, stream));
}
}
if (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r) {
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output_dtype, OType,
detail::identity_kernelLauncher<float, OType>(reinterpret_cast<const float*>(D_temp),
reinterpret_cast<OType*>(D),
input_dim * output_dim, stream););
}
} else {
// epilogue = CUBLASLT_EPILOGUE_BIAS;
// Broadcast bias and add it to D_temp and store in D. The bias vector length is m
// D_temp is of shape is (m, n) in column major and thus is of shape (n, m) in row major
// gemm_ex(A=W, B=X, transA=T)
batch_size = n;
input_dim = k;
output_dim = m;
DType output_dtype = get_transformer_engine_dtype(D_type);
DType bias_dtype = get_transformer_engine_dtype(bias_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output_dtype, OType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
bias_dtype, BType,
detail::add_bias_kernelLauncher<OType, BType>(
reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
reinterpret_cast<const BType*>(bias_ptr), reinterpret_cast<float*>(D_amax),
reinterpret_cast<const float*>(D_scale), batch_size, output_dim, stream);););
}
} else if (gelu) {
if (grad) {
// epilogue = CUBLASLT_EPILOGUE_DGELU;
// Take input from pre_gelu_out and apply GELU gradients to D_temp and store result in D
// D_temp is of shape is (m, n) in column major and thus is of shape (n, m) in row major
// gemm_ex(A=W, B=dY)
batch_size = n;
input_dim = m;
output_dim = k;
DType output_dtype = get_transformer_engine_dtype(D_type);
DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output_dtype, OType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
gelu_dtype, GType,
detail::gelu_backward_kernelLauncher<OType, GType>(
reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
reinterpret_cast<const GType*>(pre_gelu_out), batch_size, input_dim, stream);););
} else {
// epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
// Store (quantized) D_temp in pre_gelu_out, and apply GELU to D_temp then store in D
// D_temp is of shape is (m, n) in column major and thus is of shape (n, m) in row major
// gemm_ex(A=W, B=X, transA=T)
batch_size = n;
input_dim = k;
output_dim = m;
DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
gelu_dtype, GType,
detail::identity_kernelLauncher<float, GType>(reinterpret_cast<const float*>(D_temp),
reinterpret_cast<GType*>(pre_gelu_out),
batch_size * output_dim, stream););
DType output_dtype = get_transformer_engine_dtype(D_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output_dtype, OType,
detail::gelu_forward_kernelLauncher<OType>(
reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
reinterpret_cast<float*>(D_amax), reinterpret_cast<const float*>(D_scale), batch_size,
output_dim, stream););
}
} else { // No epilogue - !(bias || gelu)
if (use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
D_type == rocblas_datatype_bf8_r)) {
DType output_dtype = get_transformer_engine_dtype(D_type);
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output_dtype, OType,
detail::identity_output_kernelLauncher<OType>(
reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
reinterpret_cast<float*>(D_amax), reinterpret_cast<const float*>(D_scale), m * n,
stream););
}
}
if (((bias || gelu) && (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r)) ||
(use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
D_type == rocblas_datatype_bf8_r))) {
if (!stream_order_alloc) {
NVTE_CHECK_CUDA(hipFree(D_temp));
} else {
NVTE_CHECK_CUDA(hipFreeAsync(D_temp, stream));
}
}
}
#endif //USE_ROCBLAS
void cublas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
int ldb, int ldd, bool transa, bool transb, bool grad, void* workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
int math_sm_count, int m_split, int n_split, bool gemm_producer,
const Tensor* inputCounter, hipStream_t stream, bool nvte_use_hipblaslt = 0,
bool nvte_use_rocblas = 0, int compute_stream_offset = -1) {
  /* If no backend is requested via environment variable, use hipBLASLt unless it is disabled at build time.
     If the hipBLASLt backend is built and requested, use it regardless of the rocBLAS setting.
     Otherwise, use rocBLAS.
  */
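  // Example of the resulting selection when both backends are compiled in: exporting
  // NVTE_USE_ROCBLAS=1 (with NVTE_USE_HIPBLASLT unset) routes the GEMM through rocBLAS,
  // while leaving both variables unset falls back to hipBLASLt.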
bool use_hipblaslt = (std::getenv("NVTE_USE_HIPBLASLT") != nullptr) || nvte_use_hipblaslt;
bool use_rocblas = (std::getenv("NVTE_USE_ROCBLAS") != nullptr) || nvte_use_rocblas;
#if !defined(USE_HIPBLASLT) && !defined(USE_ROCBLAS)
#error GEMM backend is not specified
#elif !defined(USE_HIPBLASLT)
if (use_hipblaslt) {
use_hipblaslt = false;
use_rocblas = true;
std::cout << "[NOTICE] hipBLASLt is not enabled, NVTE_USE_HIPBLASLT env is ignored\n";
}
#elif !defined(USE_ROCBLAS)
if (use_rocblas) {
use_rocblas = false;
use_hipblaslt = true;
std::cout << "[NOTICE] rocBLAS is not enabled, NVTE_USE_ROCBLAS env is ignored\n";
}
#else
if (use_hipblaslt && use_rocblas) {
use_rocblas = false;
use_hipblaslt = true;
// std::cout << "[NOTICE] Two GEMM backend are enabled, hipBLASLt will be used\n";
} else if (!use_hipblaslt && !use_rocblas) {
use_rocblas = false;
use_hipblaslt = true;
// std::cout << "[NOTICE] Two GEMM backend are disabled, hipBLASLt will be used\n";
}
#endif
#ifdef USE_HIPBLASLT
if (use_hipblaslt || !use_rocblas) {
    // Check that compute_stream_offset is valid.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
hipblasLtHandle_t handle = nullptr;
if (compute_stream_offset != -1) {
// Init hipblaslt handles (once, globally)
static std::once_flag init_flag;
static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
handle = hipblaslt_handles[compute_stream_offset];
}
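    // Note: when compute_stream_offset is -1 the handle stays nullptr and is passed to
    // hipblaslt_gemm as-is; otherwise the per-stream handle selected above is used.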
hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad, workspace, workspaceSize, accumulate, use_split_accumulator, math_sm_count,
m_split, n_split, gemm_producer, inputCounter, stream, handle);
return;
}
#endif
#ifdef USE_ROCBLAS
if (use_rocblas) {
rocblas_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
(transa) ? rocblas_operation_transpose : rocblas_operation_none,
(transb) ? rocblas_operation_transpose : rocblas_operation_none, grad, workspace,
workspaceSize, accumulate, use_split_accumulator, math_sm_count, m_split, n_split,
gemm_producer, inputCounter, stream);
}
#endif
}
} //namespace transformer_engine
\ No newline at end of file
...@@ -217,7 +217,13 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
  // Choose between runtime-compiled or statically-compiled kernel
  const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
  // TODO: Using RTC may cause kernel crashes on ROCm. Therefore, set use_rtc to false there so
  // the runtime-compiled path is skipped and the crash is avoided.
#ifdef USE_ROCM
  const bool use_rtc = false;
#else
  const bool use_rtc = true;
#endif
  if (aligned && rtc::is_enabled() && use_rtc) {  // Runtime-compiled tuned kernel
    // Pick kernel config
    std::vector<KernelConfig> kernel_configs;
    kernel_configs.reserve(16);
...
...@@ -55,7 +55,7 @@ def apply_normalization(
    normalization_func = _get_normalization_func(normalization, True)
    inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
    if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32) and not zero_centered_gamma:
        out, rsigma = rmsnorm_forward(inputmat, ln_weight, ln_out, eps, True)
        return out, None, rsigma
    else:
...