Commit 42a95309 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm-1127' into 'v0.9.2-dev'

解决宽松mtp request_ids报错找不到问题

See merge request dcutoolkit/deeplearing/vllm!276
parents 0c060491 6446f76c
...@@ -52,7 +52,7 @@ class DraftProbs(ABC): # type: ignore[call-arg] ...@@ -52,7 +52,7 @@ class DraftProbs(ABC): # type: ignore[call-arg]
draft_probs: torch.Tensor draft_probs: torch.Tensor
# The request id list. # The request id list.
_req_ids: list[str] _req_ids: list[str] = []
def __init__(self, draft_probs, req_ids): def __init__(self, draft_probs, req_ids):
assert len(req_ids) == len(draft_probs) assert len(req_ids) == len(draft_probs)
...@@ -64,10 +64,15 @@ class DraftProbs(ABC): # type: ignore[call-arg] ...@@ -64,10 +64,15 @@ class DraftProbs(ABC): # type: ignore[call-arg]
tmp_req_ids: list[str]): tmp_req_ids: list[str]):
diff_req_ids = [item for item in self._req_ids if item not in tmp_req_ids] diff_req_ids = [item for item in self._req_ids if item not in tmp_req_ids]
index = [self._req_ids.index(req_id) for req_id in diff_req_ids] index = [self._req_ids.index(req_id) for req_id in diff_req_ids]
self._req_ids = diff_req_ids index_tensor = async_tensor_h2d(
self.draft_probs = self.draft_probs[index] index,
dtype=torch.int32,
target_device=self.draft_probs.device,
pin_memory=True)
self.draft_probs = self.draft_probs[index_tensor]
self.draft_probs = torch.cat([self.draft_probs, draft_probs]) self.draft_probs = torch.cat([self.draft_probs, draft_probs])
self._req_ids = diff_req_ids
self._req_ids.extend(tmp_req_ids) self._req_ids.extend(tmp_req_ids)
assert len(self._req_ids) == len(self.draft_probs) assert len(self._req_ids) == len(self.draft_probs)
...@@ -76,7 +81,12 @@ class DraftProbs(ABC): # type: ignore[call-arg] ...@@ -76,7 +81,12 @@ class DraftProbs(ABC): # type: ignore[call-arg]
if new_req_ids != self._req_ids: if new_req_ids != self._req_ids:
# Batch contents changed - prune removed sequences. # Batch contents changed - prune removed sequences.
index = [self._req_ids.index(req_id) for req_id in new_req_ids] index = [self._req_ids.index(req_id) for req_id in new_req_ids]
self.draft_probs = self.draft_probs[index] index_tensor = async_tensor_h2d(
index,
dtype=torch.int32,
target_device=self.draft_probs.device,
pin_memory=True)
self.draft_probs = self.draft_probs[index_tensor]
self._req_ids = new_req_ids self._req_ids = new_req_ids
def get_probs(self, req_ids: list[str]): def get_probs(self, req_ids: list[str]):
......
...@@ -1531,6 +1531,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1531,6 +1531,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
) )
sampler_output.sampled_token_ids = output_token_ids sampler_output.sampled_token_ids = output_token_ids
else: else:
sampling_metadata.all_greedy = True
sampling_metadata.all_random = False
sampler_output = self.sampler( sampler_output = self.sampler(
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
...@@ -1643,12 +1645,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1643,12 +1645,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if not envs.VLLM_REJECT_SAMPLE_OPT: if not envs.VLLM_REJECT_SAMPLE_OPT:
spec_token_ids = spec_result spec_token_ids = spec_result
else: else:
spec_token_ids, draft_probs = spec_result spec_token_ids, _ = spec_result
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, self.input_batch.req_ids)
else:
self.draft_probs.update(draft_probs, self.input_batch.req_ids)
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group(): if has_kv_transfer_group():
...@@ -1787,8 +1784,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1787,8 +1784,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
draft_token_ids = draft_result draft_token_ids = draft_result
spec_token_ids = draft_token_ids.tolist() spec_token_ids = draft_token_ids.tolist()
return spec_token_ids return spec_token_ids
draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
draft_token_ids, draft_probs = draft_result draft_token_ids, draft_probs = draft_result
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, draft_req_ids)
else:
self.draft_probs.update(draft_probs, draft_req_ids)
spec_token_ids = draft_token_ids.tolist() spec_token_ids = draft_token_ids.tolist()
return spec_token_ids, draft_probs return spec_token_ids, draft_probs
...@@ -3357,6 +3360,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3357,6 +3360,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
) )
sampler_output.sampled_token_ids = output_token_ids sampler_output.sampled_token_ids = output_token_ids
else: else:
sampling_metadata.all_greedy = True
sampling_metadata.all_random = False
sampler_output = self.sampler( sampler_output = self.sampler(
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
...@@ -3438,12 +3443,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3438,12 +3443,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
if not envs.VLLM_REJECT_SAMPLE_OPT: if not envs.VLLM_REJECT_SAMPLE_OPT:
spec_token_ids = spec_result spec_token_ids = spec_result
else: else:
spec_token_ids, draft_probs = spec_result spec_token_ids, _ = spec_result
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, self.input_batch.req_ids)
else:
self.draft_probs.update(draft_probs, self.input_batch.req_ids)
if max_gen_len == 1: if max_gen_len == 1:
# No spec decode tokens. # No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist() valid_sampled_token_ids = sampled_token_ids.tolist()
...@@ -3607,7 +3608,13 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3607,7 +3608,13 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
spec_token_ids = draft_token_ids.tolist() spec_token_ids = draft_token_ids.tolist()
return spec_token_ids return spec_token_ids
draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
draft_token_ids, draft_probs = draft_result draft_token_ids, draft_probs = draft_result
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, draft_req_ids)
else:
self.draft_probs.update(draft_probs, draft_req_ids)
spec_token_ids = draft_token_ids.tolist() spec_token_ids = draft_token_ids.tolist()
return spec_token_ids, draft_probs return spec_token_ids, draft_probs
return spec_token_ids return spec_token_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment