Commit 2148040f authored by zhaochao
Browse files

[DCU] Skip some tests in test_sanity.py


Signed-off-by: zhaochao <zhaochao1@sugon.com>
parent c9eab7e7
......@@ -46,7 +46,7 @@ from utils import ModelConfig
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Record initial RNG state from script run.
......@@ -81,6 +81,7 @@ def is_fp8_supported(config: ModelConfig):
return True
model_configs = {
"126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
"small": ModelConfig(2, 32, 2, 32, num_layers=2),
......@@ -369,6 +370,12 @@ def test_sanity_layernorm_linear(
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
sigma = 0.023
init_method = init_method_normal(sigma)
......@@ -395,7 +402,13 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -425,7 +438,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
num_tokens = bs * config.max_seqlen_q
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None
......@@ -464,7 +483,13 @@ def test_sanity_grouped_linear(
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None
......@@ -514,7 +539,13 @@ def test_sanity_layernorm_mlp(
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -556,7 +587,13 @@ def test_sanity_gpt(
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -722,7 +759,13 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -752,7 +795,13 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -786,7 +835,13 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
......@@ -820,7 +875,13 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
config = model_configs[model]
if fp8_recipe is not None:
if not is_fp8_supported(config):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
sigma = 0.023
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.