Unverified Commit 06be8588 authored by Lu Fang's avatar Lu Fang Committed by GitHub
Browse files

[Bugfix] Fix the speculative decoding test by setting the target dtype (#19633)

parent d1e34cc9
......@@ -57,6 +57,9 @@ from .conftest import (get_output_from_llm_generator,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
......@@ -139,6 +142,9 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
# Print spec metrics.
"disable_log_stats": False,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
......@@ -216,6 +222,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
# Print spec metrics.
"disable_log_stats": False,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
......@@ -279,6 +288,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
......@@ -464,6 +476,8 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
......@@ -523,6 +537,8 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
......@@ -589,6 +605,8 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......@@ -655,6 +673,8 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......@@ -706,6 +726,8 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......@@ -763,6 +785,8 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment