Commit 215f33b0 authored by 王敏's avatar 王敏
Browse files

[fix]修复单测test_medusa_correctness.py中的错误

parent 0ec3aa4b
......@@ -36,7 +36,7 @@ SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
MAX_SPEC_TOKENS = 5
# precision
PRECISION = "float32"
PRECISION = "float16"
@pytest.mark.parametrize(
......
......@@ -40,9 +40,10 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
self.medusa_buffers = None
if hasattr(self.model_runner.model, 'medusa_choices'):
self.medusa_choices = self.model_runner.model.medusa_choices
self.medusa_buffers = self.generate_medusa_buffers(
self.medusa_choices, device=self.device
)
if self.medusa_choices is not None:
self.medusa_buffers = self.generate_medusa_buffers(
self.medusa_choices, device=self.device
)
if self.medusa_buffers is None:
self._proposer = Top1Proposer(
......@@ -342,7 +343,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# Move the tensors in the dictionary to the specified device
medusa_buffers = {
k: v.clone().to(device)
k: (v.clone().to(device) if k != "tree_position_ids" else v.clone())
if isinstance(v, torch.Tensor)
else torch.tensor(v, device=device)
for k, v in medusa_buffers.items()
......
......@@ -512,7 +512,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
if seq_group_metadata.tree_position_ids is not None:
inter_data.input_positions[seq_idx] = seq_group_metadata.tree_position_ids.contiguous().cpu().tolist()
inter_data.input_positions[seq_idx] = seq_group_metadata.tree_position_ids.contiguous().tolist()
inter_data.tree_attn_masks[seq_idx] = seq_group_metadata.tree_attn_masks
inter_data.query_lens[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment