Commit 0ae68da1 authored by 王敏
Browse files

[fix]修复mtp的1处笔误

parent 13130b89
...@@ -1796,13 +1796,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1796,13 +1796,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
decoding=spec_decode_metadata is not None decoding=spec_decode_metadata is not None
) )
spec_token_ids = draft_token_ids.tolist()
if not envs.VLLM_REJECT_SAMPLE_OPT: if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_token_ids = draft_result draft_token_ids = draft_result
else: else:
draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
draft_token_ids, draft_probs = draft_result draft_token_ids, draft_probs = draft_result
spec_token_ids = draft_token_ids.tolist()
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
if self.draft_probs is None: if self.draft_probs is None:
self.draft_probs = DraftProbs( self.draft_probs = DraftProbs(
draft_probs, draft_req_ids) draft_probs, draft_req_ids)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment