Commit acfa43b8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.8.5.post1-opt1-wm' into 'v0.8.5.post1-opt1'

[fix]修复并行解码eagle和mlp相关单测问题

See merge request dcutoolkit/deeplearing/vllm!138
parents 9bcbaafc 08cb5d8f
......@@ -27,6 +27,8 @@ from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix
import vllm.envs as envs
os.environ["LLAMA_NN"] = "0"
# main model
MAIN_MODEL = os.path.join(models_path_prefix, "JackFram/llama-68m")
......
......@@ -59,6 +59,9 @@ PRECISION = "float16"
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......@@ -72,9 +75,9 @@ PRECISION = "float16"
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [4, 32])
@pytest.mark.parametrize("batch_size", [4, 4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
......@@ -107,6 +110,9 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......@@ -125,7 +131,7 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
},
])
@pytest.mark.parametrize("output_len", [8])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
......@@ -171,6 +177,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......@@ -182,7 +191,7 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
},
])
@pytest.mark.parametrize("output_len", [2048])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
......@@ -224,12 +233,15 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
"speculative_config": {
"model": SPEC_MODEL,
},
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("temperature", [1.0])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
......@@ -269,7 +281,7 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
"block_size": 16,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
......@@ -282,6 +294,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......@@ -323,7 +338,7 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
"block_size": 16,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
......@@ -336,6 +351,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......@@ -392,6 +410,9 @@ def test_mlp_e2e_greedy_correctness_with_padding(
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......@@ -446,6 +467,9 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
# Main model
"model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......@@ -495,6 +519,9 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......
......@@ -51,9 +51,14 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
def set_include_gpu_probs_tensor(self) -> None:
# Need include_gpu_probs_tensor for MultiStepWorker
self.model_runner.sampler.include_gpu_probs_tensor = True
if hasattr(self.model_runner.model, "sampler"):
(self.model_runner.model.sampler.include_gpu_probs_tensor) = True
def set_should_modify_greedy_probs_inplace(self) -> None:
self.model_runner.sampler.should_modify_greedy_probs_inplace = True
if hasattr(self.model_runner.model, "sampler"):
(self.model_runner.model.sampler.should_modify_greedy_probs_inplace
) = True
@torch.inference_mode()
def sampler_output(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment