Unverified Commit ee1531bc authored by Lu Fang's avatar Lu Fang Committed by GitHub
Browse files

[Bugfix][2/n] Fix speculative decoding CI - Fix test_ngram_e2e_greedy_correctness (#19644)

parent e13945f9
...@@ -14,10 +14,13 @@ MAIN_MODEL = "JackFram/llama-68m" ...@@ -14,10 +14,13 @@ MAIN_MODEL = "JackFram/llama-68m"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"model_name": "JackFram/llama-68m",
# Verify equality when cuda graphs allowed. # Verify equality when cuda graphs allowed.
"enforce_eager": False, "enforce_eager": False,
"model_name": "JackFram/llama-68m",
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"per_test_common_llm_kwargs", "per_test_common_llm_kwargs",
...@@ -59,6 +62,9 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs, ...@@ -59,6 +62,9 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", []) @pytest.mark.parametrize("per_test_common_llm_kwargs", [])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -117,6 +123,9 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, ...@@ -117,6 +123,9 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......
...@@ -17,7 +17,10 @@ from .conftest import run_equality_correctness_test ...@@ -17,7 +17,10 @@ from .conftest import run_equality_correctness_test
"model_name": "JackFram/llama-160m", "model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True "enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -75,6 +78,9 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, ...@@ -75,6 +78,9 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -128,6 +134,9 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, ...@@ -128,6 +134,9 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -182,6 +191,9 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, ...@@ -182,6 +191,9 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -256,8 +268,12 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs, ...@@ -256,8 +268,12 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"model_name": "JackFram/llama-160m", "model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......
...@@ -494,6 +494,9 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -494,6 +494,9 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# Precision
"dtype": PRECISION,
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......
...@@ -40,6 +40,9 @@ from .conftest import run_equality_correctness_test ...@@ -40,6 +40,9 @@ from .conftest import run_equality_correctness_test
# Print spec metrics. # Print spec metrics.
"disable_log_stats": False, "disable_log_stats": False,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [ @pytest.mark.parametrize("per_test_common_llm_kwargs", [
{ {
...@@ -97,6 +100,9 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -97,6 +100,9 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
# Print spec metrics. # Print spec metrics.
"disable_log_stats": False, "disable_log_stats": False,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [ @pytest.mark.parametrize("per_test_common_llm_kwargs", [
{ {
...@@ -160,6 +166,9 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -160,6 +166,9 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [ @pytest.mark.parametrize("per_test_common_llm_kwargs", [
{ {
...@@ -221,6 +230,9 @@ def test_ngram_e2e_greedy_correctness_with_preemption( ...@@ -221,6 +230,9 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -281,6 +293,9 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, ...@@ -281,6 +293,9 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -337,6 +352,9 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -337,6 +352,9 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# The original model is float32, keep it for numerical stability.
"dtype": "float32",
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......
...@@ -74,6 +74,7 @@ class EAGLE(nn.Module): ...@@ -74,6 +74,7 @@ class EAGLE(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.dtype = vllm_config.model_config.dtype
self.config = config self.config = config
architectures = getattr(self.config.model, "architectures", []) architectures = getattr(self.config.model, "architectures", [])
...@@ -250,7 +251,7 @@ class EAGLE(nn.Module): ...@@ -250,7 +251,7 @@ class EAGLE(nn.Module):
lm_head_weight = torch.zeros( lm_head_weight = torch.zeros(
self.lm_head.org_vocab_size, self.lm_head.org_vocab_size,
self.lm_head.embedding_dim, self.lm_head.embedding_dim,
dtype=self.config.torch_dtype, dtype=self.dtype,
) )
weight_loader = getattr(self.lm_head.weight, "weight_loader", weight_loader = getattr(self.lm_head.weight, "weight_loader",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment