Unverified commit b74d888c, authored by Huy Do, committed by GitHub
Browse files

Fix more broken speculative decode tests (#17450)


Signed-off-by: Huy Do <huydhn@gmail.com>
parent 2007d4d5
...@@ -205,7 +205,7 @@ def test_medusa_e2e_greedy_correctness_cuda_graph( ...@@ -205,7 +205,7 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"block_size": 8, "block_size": 16,
# 2 for small prompt, 256//8 for generated. # 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8, "num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8, "max_model_len": (2 + 256 // 8) * 8,
......
...@@ -267,7 +267,7 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, ...@@ -267,7 +267,7 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"block_size": 8, "block_size": 16,
# 2 for small prompt, 256//8 for generated. # 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8, "num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8, "max_model_len": (2 + 256 // 8) * 8,
...@@ -321,7 +321,7 @@ def test_mlp_e2e_greedy_correctness_with_preemption( ...@@ -321,7 +321,7 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"block_size": 8, "block_size": 16,
# 2 for small prompt, 256//8 for generated. # 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8, "num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8, "max_model_len": (2 + 256 // 8) * 8,
......
...@@ -152,7 +152,7 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -152,7 +152,7 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"block_size": 8, "block_size": 16,
# 2 for small prompt, 256//8 for generated. # 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8, "num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8, "max_model_len": (2 + 256 // 8) * 8,
......
...@@ -51,9 +51,14 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase): ...@@ -51,9 +51,14 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
def set_include_gpu_probs_tensor(self) -> None: def set_include_gpu_probs_tensor(self) -> None:
# Need include_gpu_probs_tensor for MultiStepWorker # Need include_gpu_probs_tensor for MultiStepWorker
self.model_runner.sampler.include_gpu_probs_tensor = True self.model_runner.sampler.include_gpu_probs_tensor = True
if hasattr(self.model_runner.model, "sampler"):
(self.model_runner.model.sampler.include_gpu_probs_tensor) = True
def set_should_modify_greedy_probs_inplace(self) -> None: def set_should_modify_greedy_probs_inplace(self) -> None:
self.model_runner.sampler.should_modify_greedy_probs_inplace = True self.model_runner.sampler.should_modify_greedy_probs_inplace = True
if hasattr(self.model_runner.model, "sampler"):
(self.model_runner.model.sampler.should_modify_greedy_probs_inplace
) = True
@torch.inference_mode() @torch.inference_mode()
def sampler_output( def sampler_output(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment