Commit 5cc8ee3e authored by zhaochao's avatar zhaochao
Browse files

[DCU] fix some bug


Signed-off-by: default avatarzhaochao <zhaochao1@sugon.com>
parent 183a88cf
...@@ -51,6 +51,7 @@ NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp ...@@ -51,6 +51,7 @@ NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
mkdir -p $TE_PATH/artifacts/tests/pytorch/test_checkpoint && python $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all --checkpoint-dir $TE_PATH/artifacts/tests/pytorch/test_checkpoint/
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
......
...@@ -216,9 +216,6 @@ def test_dot_product_attention( ...@@ -216,9 +216,6 @@ def test_dot_product_attention(
# FlashAttention backend # FlashAttention backend
if flash_attn_supported: if flash_attn_supported:
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION and config.head_dim_qk < config.head_dim_v:
pytest.skip("FlashAttention on ROCm does not support MLA with head_dim_qk < head_dim_v")
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, dtype,
config, config,
...@@ -263,21 +260,22 @@ def test_dpa_checkpoint(dtype, model_configs, model): ...@@ -263,21 +260,22 @@ def test_dpa_checkpoint(dtype, model_configs, model):
model_configs_mla = { model_configs_mla = {
#TODO:FlashAttention on ROCm only support MLA with head_dim_qk = head_dim_v
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0 # "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0
"mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0 # "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0
"mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0 # "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0
"mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1 # "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1
"mla_2_1": ModelConfig( # "mla_2_1": ModelConfig(
1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64 # 1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
), # cross, 1 # ), # cross, 1
"mla_2_2": ModelConfig( # "mla_2_2": ModelConfig(
1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128 # 1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
), # cross, 1 # ), # cross, 1
"mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference # "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
"mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference # "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference # "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference # "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference
} }
......
...@@ -50,8 +50,8 @@ from utils import ModelConfig, reset_rng_states, get_available_attention_backend ...@@ -50,8 +50,8 @@ from utils import ModelConfig, reset_rng_states, get_available_attention_backend
# Only run FP8 tests on supported devices. # Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
sm_80plus = get_device_compute_capability() >= (8, 0) sm_80plus = get_device_compute_capability() >= (8, 0)
...@@ -582,6 +582,12 @@ def _test_e2e_selective_recompute( ...@@ -582,6 +582,12 @@ def _test_e2e_selective_recompute(
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.") pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
config = model_configs[model] config = model_configs[model]
...@@ -690,8 +696,14 @@ def _test_e2e_full_recompute( ...@@ -690,8 +696,14 @@ def _test_e2e_full_recompute(
def test_gpt_full_activation_recompute( def test_gpt_full_activation_recompute(
dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
): ):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.") pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model] config = model_configs[model]
...@@ -1277,10 +1289,14 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): ...@@ -1277,10 +1289,14 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
fuse_wgrad_accumulation = True fuse_wgrad_accumulation = True
fp8_model_params = False fp8_model_params = False
fp8 = recipe is not None fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.delayed(): if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input") pytest.skip("DelayedScaling recipe is not supported with save_original_input")
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model] config = model_configs[model]
if config.max_seqlen_q % 16 != 0 and fp8: if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.") pytest.skip("FP8 requires sequence length to be divisible by 16.")
...@@ -1793,6 +1809,12 @@ def test_grouped_linear_accuracy( ...@@ -1793,6 +1809,12 @@ def test_grouped_linear_accuracy(
parallel_mode=None, parallel_mode=None,
): ):
fp8 = recipe is not None fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.") pytest.skip("FP8 parameters are not supported in debug mode.")
...@@ -1837,8 +1859,9 @@ def test_grouped_linear_accuracy( ...@@ -1837,8 +1859,9 @@ def test_grouped_linear_accuracy(
if fuse_wgrad_accumulation: if fuse_wgrad_accumulation:
weight_i = getattr(grouped_linear, f"weight{i}") weight_i = getattr(grouped_linear, f"weight{i}")
weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
if IS_HIP_EXTENSION:
os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
outputs_ref = _test_grouped_linear_accuracy( outputs_ref = _test_grouped_linear_accuracy(
sequential_linear, sequential_linear,
num_gemms, num_gemms,
...@@ -1861,7 +1884,8 @@ def test_grouped_linear_accuracy( ...@@ -1861,7 +1884,8 @@ def test_grouped_linear_accuracy(
fuse_wgrad_accumulation, fuse_wgrad_accumulation,
delay_wgrad_compute, delay_wgrad_compute,
) )
if IS_HIP_EXTENSION:
os.environ["NVTE_FORCE_ROCM_GEMM"] = "0"
# Shoule be bit-wise match # Shoule be bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0) torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
...@@ -1893,6 +1917,12 @@ def test_grouped_linear_accuracy_save_original_input( ...@@ -1893,6 +1917,12 @@ def test_grouped_linear_accuracy_save_original_input(
pytest.skip("FP8 parameters are not supported in debug mode.") pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.delayed(): if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input") pytest.skip("DelayedScaling recipe is not supported with save_original_input")
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model] config = model_configs[model]
if config.max_seqlen_q % 16 != 0 and fp8: if config.max_seqlen_q % 16 != 0 and fp8:
...@@ -2099,8 +2129,14 @@ def test_padding_grouped_linear_accuracy( ...@@ -2099,8 +2129,14 @@ def test_padding_grouped_linear_accuracy(
fp8_model_params, fp8_model_params,
parallel_mode=None, parallel_mode=None,
): ):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.") pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model] config = model_configs[model]
if config.max_seqlen_q % 16 != 0 and fp8: if config.max_seqlen_q % 16 != 0 and fp8:
...@@ -2172,6 +2208,13 @@ def test_padding_grouped_linear_accuracy_save_original_input( ...@@ -2172,6 +2208,13 @@ def test_padding_grouped_linear_accuracy_save_original_input(
pytest.skip("FP8 parameters are not supported in debug mode.") pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.delayed(): if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input") pytest.skip("DelayedScaling recipe is not supported with save_original_input")
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model] config = model_configs[model]
if config.max_seqlen_q % 16 != 0 and fp8: if config.max_seqlen_q % 16 != 0 and fp8:
...@@ -2383,8 +2426,14 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): ...@@ -2383,8 +2426,14 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("recipe", fp8_recipes)
def test_gpt_fp8_parameters(dtype, bs, model, recipe): def test_gpt_fp8_parameters(dtype, bs, model, recipe):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if NVTE_TEST_NVINSPECT_ENABLED: if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.") pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model] config = model_configs[model]
......
...@@ -890,14 +890,6 @@ class FlashAttention(torch.nn.Module): ...@@ -890,14 +890,6 @@ class FlashAttention(torch.nn.Module):
elif q_format == "thd": elif q_format == "thd":
# thd -> t(hd) # thd -> t(hd)
output = output.reshape(output.shape[0], -1) output = output.reshape(output.shape[0], -1)
# Handle output shape when V head dim differs from Q/K head dim
if value_layer.shape[-1] != query_layer.shape[-1]:
v_dim = value_layer.shape[-1]
num_heads = query_layer.shape[-2]
out_shape_heads = output.shape[:-1] + (num_heads, query_layer.shape[-1])
output = output.view(out_shape_heads)[..., :v_dim]
output = output.reshape(output.shape[:-2] + (num_heads * v_dim,))
return output.contiguous() return output.contiguous()
......
...@@ -53,7 +53,7 @@ def apply_normalization( ...@@ -53,7 +53,7 @@ def apply_normalization(
normalization_func = _get_normalization_func(normalization, True) normalization_func = _get_normalization_func(normalization, True)
inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias) inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32): if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32) and not zero_centered_gamma:
out, rsigma = rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True) out, rsigma = rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True)
return out, None, rsigma return out, None, rsigma
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment