Commit 2148040f authored by zhaochao's avatar zhaochao
Browse files

[DCU] Skip some tests in test_sanity.py


Signed-off-by: default avatarzhaochao <zhaochao1@sugon.com>
parent c9eab7e7
...@@ -46,7 +46,7 @@ from utils import ModelConfig ...@@ -46,7 +46,7 @@ from utils import ModelConfig
# Only run FP8 tests on supported devices. # Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Record initial RNG state from script run. # Record initial RNG state from script run.
...@@ -81,6 +81,7 @@ def is_fp8_supported(config: ModelConfig): ...@@ -81,6 +81,7 @@ def is_fp8_supported(config: ModelConfig):
return True return True
model_configs = { model_configs = {
"126m": ModelConfig(2, 2048, 12, 64, num_layers=12), "126m": ModelConfig(2, 2048, 12, 64, num_layers=12),
"small": ModelConfig(2, 32, 2, 32, num_layers=2), "small": ModelConfig(2, 32, 2, 32, num_layers=2),
...@@ -369,6 +370,12 @@ def test_sanity_layernorm_linear( ...@@ -369,6 +370,12 @@ def test_sanity_layernorm_linear(
if fp8_recipe is not None: if fp8_recipe is not None:
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
sigma = 0.023 sigma = 0.023
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
...@@ -395,7 +402,13 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba ...@@ -395,7 +402,13 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not is_fp8_supported(config): if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -425,7 +438,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ ...@@ -425,7 +438,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
num_tokens = bs * config.max_seqlen_q num_tokens = bs * config.max_seqlen_q
if fp8_recipe is not None: if fp8_recipe is not None:
if not is_fp8_supported(config): if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None use_fp8 = fp8_recipe is not None
...@@ -464,7 +483,13 @@ def test_sanity_grouped_linear( ...@@ -464,7 +483,13 @@ def test_sanity_grouped_linear(
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
if fp8_recipe is not None: if fp8_recipe is not None:
if not is_fp8_supported(config): if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None use_fp8 = fp8_recipe is not None
...@@ -514,7 +539,13 @@ def test_sanity_layernorm_mlp( ...@@ -514,7 +539,13 @@ def test_sanity_layernorm_mlp(
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not is_fp8_supported(config): if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -556,7 +587,13 @@ def test_sanity_gpt( ...@@ -556,7 +587,13 @@ def test_sanity_gpt(
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not is_fp8_supported(config): if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -722,7 +759,13 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): ...@@ -722,7 +759,13 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not is_fp8_supported(config): if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -752,7 +795,13 @@ def test_sanity_drop_path(dtype, fp8_recipe, model): ...@@ -752,7 +795,13 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not is_fp8_supported(config): if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -786,7 +835,13 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): ...@@ -786,7 +835,13 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not is_fp8_supported(config): if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
...@@ -820,7 +875,13 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra ...@@ -820,7 +875,13 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
if not is_fp8_supported(config): if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
sigma = 0.023 sigma = 0.023
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment