Commit 91670b05 authored by wenjh

Merge branch 'develop_v2.8' into 'main'

[DCU] Skip some tests in test_sanity.py

See merge request dcutoolkit/deeplearing/TransformerEngine!61
parents e3780e3a 3a040217
@@ -48,6 +48,7 @@ NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_tes
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
+mkdir -p $TE_PATH/artifacts/tests/pytorch/test_checkpoint && python $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all --checkpoint-dir $TE_PATH/artifacts/tests/pytorch/test_checkpoint/
 NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
@@ -260,21 +260,22 @@ def test_dpa_checkpoint(dtype, model_configs, model):
 model_configs_mla = {
+    #TODO:FlashAttention on ROCm only support MLA with head_dim_qk = head_dim_v
     # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
-    "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0
-    "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0
-    "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0
-    "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1
-    "mla_2_1": ModelConfig(
-        1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
-    ), # cross, 1
-    "mla_2_2": ModelConfig(
-        1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
-    ), # cross, 1
-    "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
-    "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
-    "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
-    "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
+    # "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0
+    # "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0
+    # "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0
+    # "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1
+    # "mla_2_1": ModelConfig(
+    #     1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
+    # ), # cross, 1
+    # "mla_2_2": ModelConfig(
+    #     1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
+    # ), # cross, 1
+    # "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
+    # "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
+    # "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
+    # "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
     "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference
 }
@@ -28,9 +28,9 @@ if IS_HIP_EXTENSION:
 from functools import cache
 # Check if FP8 is supported.
-fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
-mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
 # Reset RNG states.
 reset_rng_states()
@@ -310,6 +310,12 @@ def test_make_graphed_callables(
         pytest.skip("FP8 needed for FP8 parameters.")
     if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
         pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    if fp8 and fp8_recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     # Run model with different CUDA graph settings.
     model_config = model_configs[model_config]
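The same three availability checks recur in almost every test body touched below. Purely as a point of reference (this helper is not part of the merge), the pattern could be factored into one function; it relies only on the FP8GlobalStateManager queries and recipe methods already used in the diff:

import pytest
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

# Hypothetical helper, not in the diff: one place for the recipe/availability
# skips that the merge adds to each test body individually.
def skip_if_recipe_unsupported(fp8, fp8_recipe):
    fp8_ok, why_no_fp8 = FP8GlobalStateManager.is_fp8_available()
    mxfp8_ok, why_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
    block_ok, why_no_block = FP8GlobalStateManager.is_fp8_block_scaling_available()
    if fp8 and not fp8_ok:
        pytest.skip(why_no_fp8)
    if fp8 and fp8_recipe.mxfp8() and not mxfp8_ok:
        pytest.skip(why_no_mxfp8)
    if fp8 and fp8_recipe.float8_block_scaling() and not block_ok:
        pytest.skip(why_no_block)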
...@@ -6,7 +6,6 @@ import pytest ...@@ -6,7 +6,6 @@ import pytest
import torch import torch
import transformer_engine as te import transformer_engine as te
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import use_lightop_w8a8
from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8) from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
......
@@ -2084,7 +2084,8 @@ class TestFusedOps:
         quantized_weight: bool = False,
     ) -> None:
         """Forward GEMM + scale + add"""
+        if IS_HIP_EXTENSION and scale != 1:
+            pytest.skip("alpha must be 1.0 for hip")
         # Make input and weight shapes consistent
         out_features, in_features = weight_shape
         in_shape = list(in_shape)[:-1] + [in_features]
@@ -2469,7 +2470,8 @@ class TestFusedOps:
         quantized_weight: bool = False,
     ) -> None:
         """Backward dgrad GEMM + scale"""
+        if IS_HIP_EXTENSION and scale != 1:
+            pytest.skip("alpha must be 1.0 for hip")
         # Make input and weight shapes consistent
         out_features, in_features = weight_shape
         in_shape = list(in_shape)[:-1] + [in_features]
@@ -56,7 +56,7 @@ from utils import ModelConfig, reset_rng_states, get_available_attention_backend
 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
 sm_80plus = get_device_compute_capability() >= (8, 0)
@@ -606,6 +606,12 @@ def _test_e2e_selective_recompute(
 def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     config = model_configs[model]
@@ -714,8 +721,15 @@ def _test_e2e_full_recompute(
 def test_gpt_full_activation_recompute(
     dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
 ):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
@@ -1301,9 +1315,14 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
     fuse_wgrad_accumulation = True
     fp8_model_params = False
     fp8 = recipe is not None
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
@@ -1818,6 +1837,12 @@ def test_grouped_linear_accuracy(
     use_cutlass=False,
 ):
     fp8 = recipe is not None
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
@@ -1863,7 +1888,8 @@ def test_grouped_linear_accuracy(
         weight_i = getattr(grouped_linear, f"weight{i}")
         weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
         sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
     outputs_ref = _test_grouped_linear_accuracy(
         sequential_linear,
         num_gemms,
@@ -1886,7 +1912,8 @@
         fuse_wgrad_accumulation,
         delay_wgrad_compute,
     )
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "0"
     for o, o_ref in zip(outputs, outputs_ref):
         if use_cutlass:
             torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
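On ROCm the reference outputs above are produced with NVTE_FORCE_ROCM_GEMM set to "1" and the flag is flipped back to "0" afterwards. As an illustrative sketch only (not part of the change), the same toggle could be wrapped so the prior value of the variable is restored even if the reference run raises:

import os

# Illustrative only: set NVTE_FORCE_ROCM_GEMM for the duration of one call and
# restore whatever value (or absence) it had before, even on exceptions.
def run_with_rocm_gemm_forced(fn, *args, **kwargs):
    previous = os.environ.get("NVTE_FORCE_ROCM_GEMM")
    os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
    try:
        return fn(*args, **kwargs)
    finally:
        if previous is None:
            os.environ.pop("NVTE_FORCE_ROCM_GEMM", None)
        else:
            os.environ["NVTE_FORCE_ROCM_GEMM"] = previous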
@@ -1956,6 +1983,12 @@ def test_grouped_linear_accuracy_save_original_input(
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2162,8 +2195,14 @@ def test_padding_grouped_linear_accuracy(
     fp8_model_params,
     parallel_mode=None,
 ):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2235,6 +2274,12 @@ def test_padding_grouped_linear_accuracy_save_original_input(
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2446,8 +2491,14 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
 @pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("recipe", fp8_recipes)
 def test_gpt_fp8_parameters(dtype, bs, model, recipe):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     if NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
@@ -33,7 +33,9 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
 import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
-from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
+from transformer_engine.pytorch.onnx_extensions import te_translation_table
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
+from transformer_engine.pytorch.export import is_in_onnx_export_mode
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import get_default_init_method
 import tensorrt as trt
@@ -65,7 +67,6 @@ if mxfp8_available:
     fp8_recipes.append(recipe.MXFP8BlockScaling())
 if fp8_available:
     fp8_recipes.append(recipe.DelayedScaling())
-    fp8_recipes.append(recipe.Float8CurrentScaling())
 fp8_recipes.append(None)
 supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
@@ -82,11 +83,11 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
     ],
     outputs=[PyCustomOpDef.dt_uint8],
 )
-def trt_fp8_quantize(t, scale_inv):
+def trt_fp8_quantize(t, scale):
     """FP8 quantization extension for ONNX Runtime."""
     x = torch.from_numpy(t).cuda()
     q = te.tensor.float8_tensor.Float8Quantizer(
-        scale=1 / torch.from_numpy(scale_inv).cuda(),
+        scale=1 / torch.from_numpy(scale).cuda(),
         amax=torch.zeros([1]).cuda(),
         fp8_dtype=tex.DType.kFloat8E4M3,
     )
@@ -102,11 +103,11 @@ def trt_fp8_quantize(t, scale_inv):
     ],
     outputs=[PyCustomOpDef.dt_float],
 )
-def trt_fp8_dequantize(t, scale_inv):
+def trt_fp8_dequantize(t, scale):
     """FP8 dequantization extension for ONNX Runtime."""
     x = torch.from_numpy(t).cuda()
     q = te.tensor.float8_tensor.Float8Quantizer(
-        scale=1 / torch.from_numpy(scale_inv).cuda(),
+        scale=1 / torch.from_numpy(scale).cuda(),
         amax=torch.zeros([1]).cuda(),
         fp8_dtype=tex.DType.kFloat8E4M3,
     )
@@ -469,16 +470,22 @@ def _test_export_linear(
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
 def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_linear(fp8_recipe=fp8_recipe, precision=precision)
 @pytest.mark.parametrize("use_bias", [True, False])
 def test_export_linear_use_bias(seed_default_rng, use_bias):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_linear(use_bias=use_bias)
 @pytest.mark.parametrize("return_bias", [True, False])
 def test_export_linear_return_bias(seed_default_rng, return_bias):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_linear(return_bias=return_bias)
@@ -540,6 +547,8 @@ def test_export_layernorm_zero_centered_gamma(seed_default_rng):
 @pytest.mark.parametrize("normalization", all_normalizations)
 def test_export_layernorm_normalization(seed_default_rng, normalization):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm(normalization=normalization)
@@ -594,9 +603,7 @@ def _test_export_layernorm_linear(
         fname,
         inp,
         model,
-        # For current scaling we use Float8Quantizer in tests + amax computed by hand,
-        # which has slightly different numerics than Float8CurrentScalingQuantizer.
-        atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2,
+        atol=1e-3,
         is_fp8=fp8_recipe is not None,
         te_outputs=te_outputs,
     )
@@ -605,27 +612,39 @@ def _test_export_layernorm_linear(
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
 def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision)
 def test_export_layernorm_linear_return_ln_out(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm_linear(return_layernorm_output=True)
 def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm_linear(zero_centered_gamma=True)
 @pytest.mark.parametrize("normalization", all_normalizations[1:])
 def test_export_layernorm_linear_normalization(seed_default_rng, normalization):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm_linear(normalization=normalization)
 def test_export_layernorm_linear_no_bias(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm_linear(use_bias=False)
 def test_export_layernorm_linear_return_bias(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm_linear(return_bias=True)
@@ -684,32 +703,46 @@ def _test_export_layernorm_mlp(
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
 def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision)
 def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm_mlp(return_layernorm_output=True)
 def test_export_layernorm_mlp_return_bias(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm_mlp(return_bias=True)
 def test_export_layernorm_mlp_no_bias(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm_mlp(use_bias=False)
 def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm_mlp(zero_centered_gamma=True)
 @pytest.mark.parametrize("normalization", all_normalizations[1:])
 def test_export_layernorm_mlp_normalization(seed_default_rng, normalization):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm_mlp(normalization=normalization)
 @pytest.mark.parametrize("activation", supported_activations[1:])
 def test_export_layernorm_mlp_activation(seed_default_rng, activation):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_layernorm_mlp(activation=activation)
@@ -731,6 +764,8 @@ def test_export_core_attention(
     use_mask: bool,
     attn_mask_type: str,
 ):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     # Set dimensions (these are arbitrary).
     seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
     qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels)
@@ -932,22 +967,32 @@ def _test_export_multihead_attention(
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
 def test_export_multihead_attention_recipe(fp8_recipe, precision):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision)
 def test_export_multihead_attention_no_mask():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_multihead_attention(use_mask=False)
 def test_export_multihead_attention_no_input_layernorm():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_multihead_attention(input_layernorm=False)
 def test_export_multihead_attention_cross_attn():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_multihead_attention(attention_type="cross")
 def test_export_multihead_attention_unfused_qkv_params():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_multihead_attention(fuse_qkv_params=False)
@@ -1023,27 +1068,39 @@ def _test_export_transformer_layer(
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
 def test_export_transformer_layer_recipe(fp8_recipe, precision):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision)
 def test_export_transformer_layer_no_mask():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_transformer_layer(use_mask=False)
 def test_export_transformer_layer_output_layernorm():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_transformer_layer(output_layernorm=True)
 def test_export_transformer_layer_unfused_qkv_params():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_transformer_layer(fuse_qkv_params=False)
 def test_export_transformer_layer_zero_centered_gamma():
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_transformer_layer(zero_centered_gamma=True)
 @pytest.mark.parametrize("activation", supported_activations[1:])
 def test_export_transformer_layer_activation(activation):
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     _test_export_transformer_layer(activation=activation)
@@ -1056,7 +1113,8 @@ def test_export_gpt_generation(
     """Test that the ONNX model can correctly handle inputs with different shapes and that
     the attention mask is adjusted on-the-fly to different sequence lengths.
     """
+    if IS_HIP_EXTENSION:
+        pytest.skip("ONNX is not currently required in hip")
     # Layer configuration
     hidden_size = 64
     sequence_length = 128
@@ -1147,17 +1205,14 @@ def test_export_ctx_manager(enabled):
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 def test_trt_integration(fp8_recipe: recipe.Recipe):
+    if IS_HIP_EXTENSION:
+        pytest.skip("TRT is not supported for HIP")
     model = te.TransformerLayer(
         hidden_size=128,
         ffn_hidden_size=128,
         num_attention_heads=4,
     ).eval()
-    if type(fp8_recipe) == recipe.Float8CurrentScaling:
-        # TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
-        model = te.LayerNormMLP(128, 128)
     inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
     with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
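Every ONNX export test above repeats the same two-line HIP guard. As a purely hypothetical alternative sketch (not what this merge does, and it would also fold the differently worded TRT skip into one reason), pytest allows the whole module to be skipped once with a module-level marker:

import pytest
from torch.utils.cpp_extension import IS_HIP_EXTENSION

# Hypothetical alternative: one module-level skip instead of a per-test guard.
pytestmark = pytest.mark.skipif(
    IS_HIP_EXTENSION, reason="ONNX is not currently required in hip"
)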
@@ -46,7 +46,7 @@ from utils import ModelConfig
 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
 # Record initial RNG state from script run.
@@ -378,7 +378,13 @@ def test_sanity_layernorm_linear(
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
@@ -436,7 +442,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
     num_tokens = bs * config.max_seqlen_q
     if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
     use_fp8 = fp8_recipe is not None
@@ -525,7 +537,13 @@ def test_sanity_layernorm_mlp(
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
@@ -567,7 +585,13 @@ def test_sanity_gpt(
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
@@ -733,7 +757,13 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
@@ -763,7 +793,13 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
@@ -797,7 +833,13 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
@@ -831,7 +873,13 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
     config = model_configs[model]
     if fp8_recipe is not None:
-        if not is_fp8_supported(config):
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
     sigma = 0.023
@@ -53,7 +53,7 @@ def apply_normalization(
     normalization_func = _get_normalization_func(normalization, True)
     inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
-    if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32):
+    if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32) and not zero_centered_gamma:
        out, rsigma = rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True)
        return out, None, rsigma
     else:
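The new `and not zero_centered_gamma` clause keeps the lightop RMSNorm fast path out of the zero-centered-gamma case, where Transformer Engine applies the weight as (1 + gamma) rather than gamma. A plain-PyTorch reference (illustrative only, not the fused kernel) shows the difference the guard protects against:

import torch

# Illustrative RMSNorm reference, not the fused lightop kernel: with
# zero_centered_gamma the stored weight must be shifted by 1 before scaling,
# so a kernel that multiplies by the raw weight cannot be reused unchanged.
def rmsnorm_reference(x, weight, eps=1e-5, zero_centered_gamma=False):
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    gamma = 1.0 + weight if zero_centered_gamma else weight
    return (x / rms) * gamma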