Commit 1edc9e13 authored by wenjh

Merge branch 'develop_v2.8' into release_v2.8

parents 5e7dd67e 3a040217
......@@ -49,6 +49,7 @@ NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_tes
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
mkdir -p $TE_PATH/artifacts/tests/pytorch/test_checkpoint && python3 $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all --checkpoint-dir $TE_PATH/artifacts/tests/pytorch/test_checkpoint/
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
......
......@@ -340,22 +340,23 @@ def test_dpa_softmax(dtype, model_configs, model):
model_configs_mla = {
# test: ModelConfig(b, sq, hq, dqk)
"mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
"mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),
"mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),
"mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),
"mla_2_1": ModelConfig(
1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
),
"mla_2_2": ModelConfig(
1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
),
"mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),
"mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),
"mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),
"mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),
"mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),
# TODO: FlashAttention on ROCm only supports MLA with head_dim_qk == head_dim_v.
# # test: ModelConfig(b, sq, hq, dqk)
# "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0
# "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0
# "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0
# "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1
# "mla_2_1": ModelConfig(
# 1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
# ), # cross, 1
# "mla_2_2": ModelConfig(
# 1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
# ), # cross, 1
# "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
# "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
# "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
# "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference
}
......
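A short, hedged illustration of the constraint named in the TODO above: on ROCm, FlashAttention supports MLA only when the QK head dim matches the V head dim, which is why "mla_3_4" (dqk = 160, head_dim_v = 160) is the only config left enabled. The snippet below restates that filter with hypothetical names, using head-dim pairs taken from the configs above.

# Minimal sketch (hypothetical helper, not part of the test file): keep only
# MLA configs whose QK head dim equals the V head dim, as ROCm FlashAttention requires.
mla_head_dims = {
    "mla_3_2": (192, 128),
    "mla_3_3": (160, 128),
    "mla_3_4": (160, 160),
}
rocm_supported = {name for name, (dqk, dv) in mla_head_dims.items() if dqk == dv}
assert rocm_supported == {"mla_3_4"}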
......@@ -17,6 +17,7 @@ import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from test_numerics import (
_emulate_linear,
......@@ -47,7 +48,6 @@ TEST_NR = 0
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
if tp_size is None:
tp_size = WORLD_SIZE
......@@ -72,6 +72,16 @@ def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None,
def _init_model(weight, parallel_mode=None, tp_group=None, name="linear"):
if IS_HIP_EXTENSION:
model = transformer_engine.pytorch.Linear(
IN_SIZE,
OUT_SIZE,
name=name,
bias=False,
parallel_mode=parallel_mode,
tp_group=(tp_group or NCCL_WORLD if parallel_mode else None),
)
else:
model = transformer_engine.pytorch.Linear(
IN_SIZE,
OUT_SIZE,
......@@ -363,7 +373,6 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
)
set_weight_tensor_tp_group_reduce(True) # reset
@run_debug_test
def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
from test_log import LOG_QUANTIZED_CONFIG
......@@ -580,6 +589,9 @@ def test_fake_quant_fp8(
"dgrad_fp8": not (dgrad_weight or dgrad_grad),
"wgrad_fp8": not (wgrad_grad or wgrad_input),
}
if IS_HIP_EXTENSION:
if fp8_kwargs["fprop_fp8"] or fp8_kwargs["dgrad_fp8"] or fp8_kwargs["wgrad_fp8"]:
return # Output type 32 (FP32) does not support int8 simulation.
if WORLD_RANK == 0:
fake_quant_fp8_create_config(
fprop_inp,
......@@ -667,6 +679,10 @@ if __name__ == "__main__":
random.seed(SEED)
_init_distributed()
# Output type 32 (FP32) does not support int8 simulation on ROCm.
if not IS_HIP_EXTENSION:
test_log_expert_parallel()
for parallel_mode in ["column", "row"]:
for gather_weight in [True, False]:
......@@ -676,6 +692,11 @@ if __name__ == "__main__":
for parallel_mode in ["row", "column"]:
test_disable_fp8_layer(parallel_mode)
# Output type 32 (FP32) does not support int8 simulation on ROCm.
if not IS_HIP_EXTENSION:
# test_disable_fp8_gemms
_run_test_with_combinations(
test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"]
......@@ -690,7 +711,10 @@ if __name__ == "__main__":
extra_args=["column", "row"],
sample_size=20,
)
# Output type 32 (FP32) does not support int8 simulation on ROCm.
if not IS_HIP_EXTENSION:
_run_test_with_combinations(
test_per_tensor_scaling,
all_boolean,
......
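The ROCm gating above repeats for several test entry points in this file. As a hedged sketch (helper name hypothetical, not part of the test suite), the same intent could be captured once with a decorator; since this file drives tests from `__main__` rather than pytest, it simply reports and returns:

import functools

from torch.utils.cpp_extension import IS_HIP_EXTENSION


def skip_on_hip(reason):
    """Skip a test entry point on ROCm builds (hypothetical helper)."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if IS_HIP_EXTENSION:
                print(f"Skipping {fn.__name__}: {reason}")
                return None
            return fn(*args, **kwargs)
        return wrapper
    return decorator

For example, `@skip_on_hip("Output type 32 (FP32) does not support int8 simulation.")` above `test_log_expert_parallel` would replace one of the inline branches.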
......@@ -733,7 +733,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
def test_linear():
"""Run linear layer tests with various configurations."""
kwargs_list = [
base_kwargs_list = [
{},
{"bias": False},
{"init_method": _constant},
......@@ -743,7 +743,15 @@ def test_linear():
{"delay_wgrad_compute": True},
{"save_original_input": True},
]
# TODO: The blockwise recipe does not currently support bias=True.
# On AMD platforms with the fp8_block_scaling recipe, keep only the
# configurations that explicitly set bias=False.
if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
    kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
else:
    kwargs_list = base_kwargs_list
for kwargs in kwargs_list:
if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
continue
......@@ -913,7 +921,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
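To make the filtering above concrete, here is a tiny hedged illustration (values hypothetical) of what the list comprehension keeps: only entries that explicitly set `bias=False`; an empty dict, which leaves the layer's bias at its default of True, is dropped.

# Hedged illustration of the bias filter used above.
base_kwargs_list = [{}, {"bias": False}, {"init_method": None}]
kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
assert kwargs_list == [{"bias": False}]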
def test_layernorm_linear():
kwargs_list = [
base_kwargs_list = [
{},
{"bias": False},
{"init_method": _constant},
......@@ -924,7 +932,15 @@ def test_layernorm_linear():
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
]
# TODO: The blockwise recipe does not currently support bias=True.
# On AMD platforms with the fp8_block_scaling recipe, keep only the
# configurations that explicitly set bias=False.
if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
    kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
else:
    kwargs_list = base_kwargs_list
for kwargs in kwargs_list:
for parallel_mode in ["column"]:
for sequence_parallel in [False, True]:
......@@ -1019,7 +1035,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
def test_layernorm_mlp():
kwargs_list = [
base_kwargs_list = [
{},
{"init_method": _constant},
{"output_layer_init_method": _constant},
......@@ -1033,7 +1049,15 @@ def test_layernorm_mlp():
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
]
# TODO: The blockwise recipe does not currently support bias=True.
# On AMD platforms with the fp8_block_scaling recipe, keep only the
# configurations that explicitly set bias=False.
if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
    kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
else:
    kwargs_list = base_kwargs_list
for kwargs in kwargs_list:
for set_parallel_mode in [True]:
for sequence_parallel in [False, True]:
......@@ -1108,7 +1132,7 @@ def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs):
def test_transformer_layer():
kwargs_list = [
base_kwargs_list = [
{},
{"num_gqa_groups": 4},
{"init_method": _constant},
......@@ -1128,6 +1152,15 @@ def test_transformer_layer():
{"fuse_qkv_params": True},
{"activation": "relu"},
]
# TODO: The blockwise recipe does not currently support bias=True.
# On AMD platforms with the fp8_block_scaling recipe, keep only the
# configurations that explicitly set bias=False.
if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
    kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
else:
    kwargs_list = base_kwargs_list
for kwargs in kwargs_list:
for sequence_parallel in [False, True]:
......
......@@ -9,7 +9,8 @@ from pathlib import Path
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine as te
"""
Distributed numerics tests
......@@ -66,4 +67,15 @@ def test_distributed(quantization):
pytest.skip(reason_for_no_fp8_block_scaling)
if quantization == "nvfp4" and not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
    import importlib

    # Force int8 simulation of FP8 on ROCm and reload so the flag takes effect.
    ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8")
    os.environ["NVTE_INT8_SIM_FP8"] = "1"
    importlib.reload(te.pytorch.fp8)
_run_test(quantization)
if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
    # Restore the caller's environment, then reload again.
    if ori_int8_sim_fp8 is None:
        del os.environ["NVTE_INT8_SIM_FP8"]
    else:
        os.environ["NVTE_INT8_SIM_FP8"] = ori_int8_sim_fp8
    importlib.reload(te.pytorch.fp8)
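The set/reload/restore dance above is a recurring pattern; a context manager keeps the restore logic in one place. A minimal sketch, assuming the same `NVTE_INT8_SIM_FP8` flag and module-reload behavior (helper name hypothetical):

import contextlib
import importlib
import os


@contextlib.contextmanager
def reloaded_with_env(module, name, value):
    """Set an env var, reload `module`, and undo both on exit (hypothetical helper)."""
    original = os.environ.get(name)
    os.environ[name] = value
    importlib.reload(module)
    try:
        yield
    finally:
        if original is None:
            del os.environ[name]
        else:
            os.environ[name] = original
        importlib.reload(module)

Usage would then reduce to `with reloaded_with_env(te.pytorch.fp8, "NVTE_INT8_SIM_FP8", "1"): _run_test(quantization)`.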
......@@ -28,9 +28,9 @@ if IS_HIP_EXTENSION:
from functools import cache
# Check if FP8 is supported.
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Reset RNG states.
reset_rng_states()
......@@ -367,6 +367,12 @@ def test_make_graphed_callables(
)
if fp8_params:
pytest.skip("NVFP4 params not supported")
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Run model with different CUDA graph settings.
model_config = model_configs[model_config]
......
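The availability guards added above recur across many tests in this commit. A hedged sketch of a shared helper (name hypothetical), assuming the module-level flags `fp8_available`, `fp8_block_scaling_available`, and `mxfp8_available` already defined in these files:

import pytest


def skip_if_recipe_unavailable(fp8, fp8_recipe):
    """Skip when the requested FP8 recipe is unsupported on this device (hypothetical helper)."""
    if not fp8:
        return
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8_recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)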
......@@ -6,7 +6,6 @@ import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import use_lightop_w8a8
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
......
......@@ -2111,7 +2111,8 @@ class TestFusedOps:
quantized_weight: bool = False,
) -> None:
"""Forward GEMM + scale + add"""
if IS_HIP_EXTENSION and scale != 1:
pytest.skip("alpha must be 1.0 for hip")
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
......@@ -2496,7 +2497,8 @@ class TestFusedOps:
quantized_weight: bool = False,
) -> None:
"""Backward dgrad GEMM + scale"""
if IS_HIP_EXTENSION and scale != 1:
pytest.skip("alpha must be 1.0 for hip")
# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
......
......@@ -56,7 +56,7 @@ from utils import ModelConfig, reset_rng_states, get_available_attention_backend
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
sm_80plus = get_device_compute_capability() >= (8, 0)
......@@ -606,6 +606,13 @@ def _test_e2e_selective_recompute(
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
config = model_configs[model]
......@@ -714,8 +721,15 @@ def _test_e2e_full_recompute(
def test_gpt_full_activation_recompute(
dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
......@@ -1301,9 +1315,14 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
fuse_wgrad_accumulation = True
fp8_model_params = False
fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.max_seqlen_q % 16 != 0 and fp8:
......@@ -1818,6 +1837,12 @@ def test_grouped_linear_accuracy(
use_cutlass=False,
):
fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
......@@ -1863,7 +1888,8 @@ def test_grouped_linear_accuracy(
weight_i = getattr(grouped_linear, f"weight{i}")
weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
if IS_HIP_EXTENSION:
os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
outputs_ref = _test_grouped_linear_accuracy(
sequential_linear,
num_gemms,
......@@ -1886,7 +1912,8 @@ def test_grouped_linear_accuracy(
fuse_wgrad_accumulation,
delay_wgrad_compute,
)
if IS_HIP_EXTENSION:
os.environ["NVTE_FORCE_ROCM_GEMM"] = "0"
for o, o_ref in zip(outputs, outputs_ref):
if use_cutlass:
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
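Setting `NVTE_FORCE_ROCM_GEMM` to "1" and then hard-coding it back to "0" works here, but it clobbers any value the caller had exported. A hedged sketch of a side-effect-free variant (helper name hypothetical):

import os


def with_forced_rocm_gemm(fn, *args, **kwargs):
    """Run fn with NVTE_FORCE_ROCM_GEMM=1, restoring the caller's value afterwards."""
    prev = os.environ.get("NVTE_FORCE_ROCM_GEMM")
    os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
    try:
        return fn(*args, **kwargs)
    finally:
        if prev is None:
            os.environ.pop("NVTE_FORCE_ROCM_GEMM", None)
        else:
            os.environ["NVTE_FORCE_ROCM_GEMM"] = prev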
......@@ -1956,6 +1983,12 @@ def test_grouped_linear_accuracy_save_original_input(
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.max_seqlen_q % 16 != 0 and fp8:
......@@ -2162,8 +2195,14 @@ def test_padding_grouped_linear_accuracy(
fp8_model_params,
parallel_mode=None,
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.max_seqlen_q % 16 != 0 and fp8:
......@@ -2235,6 +2274,12 @@ def test_padding_grouped_linear_accuracy_save_original_input(
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.max_seqlen_q % 16 != 0 and fp8:
......@@ -2446,8 +2491,14 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
......
......@@ -33,7 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
from transformer_engine.pytorch.onnx_extensions import te_translation_table
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import get_default_init_method
import tensorrt as trt
......@@ -65,7 +67,6 @@ if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_available:
fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(None)
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
......@@ -82,11 +83,11 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
],
outputs=[PyCustomOpDef.dt_uint8],
)
def trt_fp8_quantize(t, scale_inv):
def trt_fp8_quantize(t, scale):
"""FP8 quantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
q = te.tensor.float8_tensor.Float8Quantizer(
scale=1 / torch.from_numpy(scale_inv).cuda(),
scale=1 / torch.from_numpy(scale).cuda(),
amax=torch.zeros([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
......@@ -102,11 +103,11 @@ def trt_fp8_quantize(t, scale_inv):
],
outputs=[PyCustomOpDef.dt_float],
)
def trt_fp8_dequantize(t, scale_inv):
def trt_fp8_dequantize(t, scale):
"""FP8 dequantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
q = te.tensor.float8_tensor.Float8Quantizer(
scale=1 / torch.from_numpy(scale_inv).cuda(),
scale=1 / torch.from_numpy(scale).cuda(),
amax=torch.zeros([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
......@@ -469,16 +470,22 @@ def _test_export_linear(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_linear(fp8_recipe=fp8_recipe, precision=precision)
@pytest.mark.parametrize("use_bias", [True, False])
def test_export_linear_use_bias(seed_default_rng, use_bias):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_linear(use_bias=use_bias)
@pytest.mark.parametrize("return_bias", [True, False])
def test_export_linear_return_bias(seed_default_rng, return_bias):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_linear(return_bias=return_bias)
......@@ -540,6 +547,8 @@ def test_export_layernorm_zero_centered_gamma(seed_default_rng):
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_normalization(seed_default_rng, normalization):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm(normalization=normalization)
......@@ -594,9 +603,7 @@ def _test_export_layernorm_linear(
fname,
inp,
model,
# For current scaling we use Float8Quantizer in tests + amax computed by hand,
# which has slightly different numerics than Float8CurrentScalingQuantizer.
atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2,
atol=1e-3,
is_fp8=fp8_recipe is not None,
te_outputs=te_outputs,
)
......@@ -605,27 +612,39 @@ def _test_export_layernorm_linear(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision)
def test_export_layernorm_linear_return_ln_out(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_linear(return_layernorm_output=True)
def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_linear(zero_centered_gamma=True)
@pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_linear_normalization(seed_default_rng, normalization):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_linear(normalization=normalization)
def test_export_layernorm_linear_no_bias(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_linear(use_bias=False)
def test_export_layernorm_linear_return_bias(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_linear(return_bias=True)
......@@ -684,32 +703,46 @@ def _test_export_layernorm_mlp(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision)
def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(return_layernorm_output=True)
def test_export_layernorm_mlp_return_bias(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(return_bias=True)
def test_export_layernorm_mlp_no_bias(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(use_bias=False)
def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(zero_centered_gamma=True)
@pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_mlp_normalization(seed_default_rng, normalization):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(normalization=normalization)
@pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_layernorm_mlp_activation(seed_default_rng, activation):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_layernorm_mlp(activation=activation)
......@@ -731,6 +764,8 @@ def test_export_core_attention(
use_mask: bool,
attn_mask_type: str,
):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
# Set dimensions (these are arbitrary).
seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels)
......@@ -932,22 +967,32 @@ def _test_export_multihead_attention(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_multihead_attention_recipe(fp8_recipe, precision):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision)
def test_export_multihead_attention_no_mask():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_multihead_attention(use_mask=False)
def test_export_multihead_attention_no_input_layernorm():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_multihead_attention(input_layernorm=False)
def test_export_multihead_attention_cross_attn():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_multihead_attention(attention_type="cross")
def test_export_multihead_attention_unfused_qkv_params():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_multihead_attention(fuse_qkv_params=False)
......@@ -1023,27 +1068,39 @@ def _test_export_transformer_layer(
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_transformer_layer_recipe(fp8_recipe, precision):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision)
def test_export_transformer_layer_no_mask():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_transformer_layer(use_mask=False)
def test_export_transformer_layer_output_layernorm():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_transformer_layer(output_layernorm=True)
def test_export_transformer_layer_unfused_qkv_params():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_transformer_layer(fuse_qkv_params=False)
def test_export_transformer_layer_zero_centered_gamma():
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_transformer_layer(zero_centered_gamma=True)
@pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_transformer_layer_activation(activation):
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
_test_export_transformer_layer(activation=activation)
......@@ -1056,7 +1113,8 @@ def test_export_gpt_generation(
"""Test that the ONNX model can correctly handle inputs with different shapes and that
the attention mask is adjusted on-the-fly to different sequence lengths.
"""
if IS_HIP_EXTENSION:
pytest.skip("ONNX is not currently required in hip")
# Layer configuration
hidden_size = 64
sequence_length = 128
......@@ -1147,17 +1205,14 @@ def test_export_ctx_manager(enabled):
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_trt_integration(fp8_recipe: recipe.Recipe):
if IS_HIP_EXTENSION:
pytest.skip("TRT is not supported for HIP")
model = te.TransformerLayer(
hidden_size=128,
ffn_hidden_size=128,
num_attention_heads=4,
).eval()
if type(fp8_recipe) == recipe.Float8CurrentScaling:
# TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
model = te.LayerNormMLP(128, 128)
inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
......
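Since the HIP guard above is repeated in nearly every ONNX export test, a module-level mark would skip them all at collection time. This is standard pytest usage, sketched here under the assumption that the whole module should be skipped on HIP:

import pytest
from torch.utils.cpp_extension import IS_HIP_EXTENSION

# Skips every test in the module when running on a ROCm/HIP build.
pytestmark = pytest.mark.skipif(
    IS_HIP_EXTENSION, reason="ONNX support is not currently required on HIP"
)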
......@@ -46,7 +46,7 @@ from utils import ModelConfig
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Record initial RNG state from script run.
......@@ -388,7 +388,13 @@ def test_sanity_layernorm_linear(
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
......@@ -450,7 +456,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
num_tokens = bs * config.max_seqlen_q
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
......@@ -543,7 +555,13 @@ def test_sanity_layernorm_mlp(
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
......@@ -587,7 +605,13 @@ def test_sanity_gpt(
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
......@@ -759,7 +783,13 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
......@@ -791,7 +821,13 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
......@@ -827,7 +863,13 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
......@@ -863,7 +905,13 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
if fp8_recipe.nvfp4() and dtype == torch.float16:
pytest.skip("FP16 output for NVFP4 not supported")
......
......@@ -1561,15 +1561,6 @@ void nvte_cublas_batchgemm_tensorwise_int8(const NVTETensor A, const NVTETensor
NVTE_ERROR("TT layout not allowed.");
}
hipblasLtHandle_t handle = nullptr;
// Init hipblaslt handles (once, globally)
static std::once_flag init_flag;
static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
handle = hipblaslt_handles[0];
NVTE_ERROR("Remove nvte_cublas_batchgemm_tensorwise_int8 for now.");
}
......
......@@ -465,106 +465,6 @@ transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t)
namespace {
static class HandlePool {
public:
hipblasLtHandle_t get(int device_id) {
std::lock_guard<std::mutex> lock(mt);
if (pool.empty()) {
int device_count = 0;
NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
pool.resize(device_count);
return nullptr;
}
if (!pool[device_id].empty()) {
hipblasLtHandle_t h = pool[device_id].front();
pool[device_id].pop_front();
return h;
}
return nullptr;
}
hipblasLtHandle_t obtain(int device_id) {
hipblasLtHandle_t h = get(device_id);
if (h == nullptr) {
NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&h));
}
return h;
}
void store(const std::vector<hipblasLtHandle_t>& handles) {
std::lock_guard<std::mutex> lock(mt);
if (pool.empty()) {
std::cout << "[ERROR] Attempt to store handles to invalid pool" << std::endl;
}
for (unsigned int i = 0; i < pool.size(); i++) {
if (handles[i] != nullptr) {
pool[i].push_front(handles[i]);
}
}
}
~HandlePool() {
#if DESTROY_HIPBLASLT_HANDLES_POOL
std::lock_guard<std::mutex> lock(mt);
for (auto& hlist : pool) {
for (auto& h : hlist) {
hipblasLtDestroy(h);
}
}
pool.clear();
#endif
}
inline size_t get_size() const { return pool.size(); }
private:
std::mutex mt;
using Pool = std::vector<std::forward_list<hipblasLtHandle_t>>;
// The destruction order of thread_local and global objects is not guaranteed.
// As a simple workaround, make the pool storage "leaky":
// do not destruct it and do not destroy the hipblasLt handles;
// let the OS reclaim them on application exit.
#if DESTROY_HIPBLASLT_HANDLES_POOL
Pool pool;
#else
Pool& pool = *new Pool();
#endif
} handle_pool;
thread_local static class HandleCache {
public:
hipblasLtHandle_t get(int device_id) const { return d.empty() ? nullptr : d[device_id]; }
hipblasLtHandle_t obtain(int device_id) {
hipblasLtHandle_t h = get(device_id);
if (h) {
return h;
}
h = handle_pool.obtain(device_id);
set(device_id, h);
return h;
}
void set(int device_id, hipblasLtHandle_t h) {
if (d.empty()) {
d.resize(handle_pool.get_size());
}
d[device_id] = h;
}
~HandleCache() {
if (!d.empty()) {
handle_pool.store(d);
}
}
private:
std::vector<hipblasLtHandle_t> d;
} cached_handles;
class csv_helper {
public:
struct start {};
......@@ -987,18 +887,12 @@ static inline int getIntEnv(const char* name, int defval, int minval) {
} //namespace
/* Warning: only call once per device!
 * When calling nvte_multi_stream_cublas_gemm with the hipblaslt backend,
 * multiple handles corresponding to the compute streams must be created
 * to avoid one handle being used by multiple streams concurrently.
 */
static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
  NVTE_CHECK(hipblaslt_handles != nullptr);
  for (int i = 0; i < compute_num_streams; i++) {
    NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&hipblaslt_handles[i]));
  }
}

static inline void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
  NVTE_CHECK_HIPBLASLT(hipblasLtCreate(handle));
}

using hipBlasLtHandleManager = detail::HandleManager<hipblasLtHandle_t, CreateHipBlasLtHandle>;
transformer_engine::DType get_transformer_engine_dtype_from_hipblaslt_dtype(const hipDataType t) {
using namespace transformer_engine;
switch (t) {
......@@ -1018,8 +912,7 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
int ldb, int ldd, hipblasOperation_t transa, hipblasOperation_t transb,
bool grad, void* workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
bool gemm_producer, const Tensor* inputCounter, hipStream_t stream,
hipblasLtHandle_t handle) {
bool gemm_producer, const Tensor* inputCounter, hipStream_t stream) {
void* A = inputA->data.dptr;
void* A_scale_inverse = inputA->scale_inv.dptr;
float* A_scale_inverse_float = (float*)(inputA->scale_inv.dptr);
......@@ -1064,12 +957,7 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
int device_id;
NVTE_CHECK_CUDA(hipGetDevice(&device_id));
if (handle == nullptr) {
handle = cached_handles.get(device_id);
if (handle == nullptr) {
handle = cached_handles.obtain(device_id);
}
}
hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
hipblasLtMatmulDesc_t operationDesc = nullptr;
hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
......@@ -1352,82 +1240,41 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}
class userArgsManager {
 public:
  userArgsManager() {}

  ~userArgsManager() {
    // Release all userArgs when the manager is destroyed
    for (auto& device_pair : userArgs_map_) {
      hipFree(device_pair.second);  // Only one userArgs per device
    }
  }

  // Get a userArgs for the given device (creates if necessary)
  hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
    std::lock_guard<std::mutex> lock(mutex_);
    // Check if the userArgs for this device exists
    auto device_it = userArgs_map_.find(device_id);
    if (device_it != userArgs_map_.end()) {
      return device_it->second;
    }
    // Create a new userArgs for this device if it doesn't exist
    hipblaslt_ext::UserArguments* userArgs;
    NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
    // Store the userArgs in the map for this device
    userArgs_map_[device_id] = userArgs;
    return userArgs;
  }

 private:
  std::unordered_map<int, hipblaslt_ext::UserArguments*>
      userArgs_map_;  // Map from device_id to userArgs
  std::mutex mutex_;
};

class d_userArgsManager {
 public:
  d_userArgsManager() {}

  ~d_userArgsManager() {
    // Release all userArgs when the manager is destroyed
    for (auto& device_pair : d_userArgs_map_) {
      hipFree(device_pair.second);  // Only one userArgs per device
    }
  }

  // Get a userArgs for the given device (creates if necessary)
  hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
    std::lock_guard<std::mutex> lock(mutex_);
    // Check if the userArgs for this device exists
    auto device_it = d_userArgs_map_.find(device_id);
    if (device_it != d_userArgs_map_.end()) {
      return device_it->second;
    }
    // Create a new userArgs for this device if it doesn't exist
    hipblaslt_ext::UserArguments* d_userArgs;
    NVTE_CHECK_CUDA(hipMalloc(&d_userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
    // Store the userArgs in the map for this device
    d_userArgs_map_[device_id] = d_userArgs;
    return d_userArgs;
  }

 private:
  std::unordered_map<int, hipblaslt_ext::UserArguments*>
      d_userArgs_map_;  // Map from device_id to userArgs
  std::mutex mutex_;
};

// Define a static userArgs manager
static userArgsManager UAManager;
static d_userArgsManager d_UAManager;

struct HipBlasLtUserArgsDeleter {
  bool host = false;

  void operator()(hipblaslt_ext::UserArguments* ptr) const noexcept {
    if (ptr == nullptr) return;
    // hipHostMalloc allocations must be released with hipHostFree, not hipFree.
    if (host) {
      (void)hipHostFree(ptr);
    } else {
      (void)hipFree(ptr);
    }
  }
};

using HipBlasLtUserArgsPtr = std::unique_ptr<hipblaslt_ext::UserArguments, HipBlasLtUserArgsDeleter>;

inline HipBlasLtUserArgsPtr make_hipblaslt_user_args_ptr(size_t size, bool host) {
  hipblaslt_ext::UserArguments* raw_ptr = nullptr;
  if (host) {
    NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
  } else {
    NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
  }
  return HipBlasLtUserArgsPtr(raw_ptr, HipBlasLtUserArgsDeleter{host});
}

inline hipblaslt_ext::UserArguments* get_hipblaslt_user_args(size_t size, bool host) {
  // Per-thread caches keyed by group count: one for pinned host memory,
  // one for device memory.
  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> host_userargs_cache;
  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> device_userargs_cache;
  auto& user_args_cache = host ? host_userargs_cache : device_userargs_cache;
  auto size_it = user_args_cache.find(size);
  if (size_it != user_args_cache.end()) {
    return size_it->second.get();
  }
  HipBlasLtUserArgsPtr user_args = make_hipblaslt_user_args_ptr(size, host);
  hipblaslt_ext::UserArguments* raw_ptr = user_args.get();
  user_args_cache[size] = std::move(user_args);
  return raw_ptr;
}
void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
......@@ -1438,23 +1285,13 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Check compute_stream_offset valid.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
int device_id;
hipGetDevice(&device_id);
hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
hipblasLtHandle_t handle = nullptr;
if (compute_stream_offset != -1) {
// Init hipblaslt handles (once, globally)
static std::once_flag init_flag;
static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
handle = hipblaslt_handles[compute_stream_offset];
}
hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype);
const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype);
......@@ -1972,20 +1809,10 @@ void cublas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
// Check compute_stream_offset valid.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
hipblasLtHandle_t handle = nullptr;
if (compute_stream_offset != -1) {
// Init hipblaslt handles (once, globally)
static std::once_flag init_flag;
static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
handle = hipblaslt_handles[compute_stream_offset];
}
hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad, workspace, workspaceSize, accumulate, use_split_accumulator, math_sm_count,
m_split, n_split, gemm_producer, inputCounter, stream, handle);
m_split, n_split, gemm_producer, inputCounter, stream);
return;
}
......
......@@ -217,7 +217,13 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
// Choose between runtime-compiled or statically-compiled kernel
const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel
// TODO: Using RTC can cause kernel crashes on ROCm, so use_rtc is set to false there to avoid RTC until the crash is fixed.
#ifdef USE_ROCM
const bool use_rtc = false;
#else
const bool use_rtc = true;
#endif
if (aligned && rtc::is_enabled() && use_rtc) { // Runtime-compiled tuned kernel
// Pick kernel config
std::vector<KernelConfig> kernel_configs;
kernel_configs.reserve(16);
......
......@@ -55,7 +55,7 @@ def apply_normalization(
normalization_func = _get_normalization_func(normalization, True)
inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32):
if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32) and not zero_centered_gamma:
out, rsigma = rmsnorm_forward(inputmat, ln_weight, ln_out, eps, True)
return out, None, rsigma
else:
......
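For readability, the fast-path condition amended above (which now also requires `zero_centered_gamma` to be false) can be restated as a predicate. A hedged sketch mirroring the diff's condition (function name hypothetical):

import torch


def can_use_lightop_rmsnorm(enable_lightop, ln_bias, normalization,
                            output_quantizer, output_dtype, zero_centered_gamma):
    """True when the lightop RMSNorm fast path applies (mirrors the condition above)."""
    return (
        enable_lightop
        and ln_bias is None
        and normalization == "RMSNorm"
        and output_quantizer is None
        and output_dtype in (torch.bfloat16, torch.float16, torch.float32)
        and not zero_centered_gamma
    )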