Commit 67f4b1b4 authored by zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-wm-1209' into 'v0.9.2-dev'

[feat]零消耗适配宽松mtp

See merge request dcutoolkit/deeplearing/vllm!290
parents 4c4cfb18 740e060f
...@@ -1645,7 +1645,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1645,7 +1645,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Speculative decoding is not enabled. # Speculative decoding is not enabled.
spec_token_ids = None spec_token_ids = None
else: else:
spec_result = self.propose_draft_token_ids( spec_token_ids = self.propose_draft_token_ids(
scheduler_output, scheduler_output,
valid_sampled_token_ids, valid_sampled_token_ids,
sampling_metadata, sampling_metadata,
...@@ -1655,10 +1655,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1655,10 +1655,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata, spec_decode_metadata,
attn_metadata, attn_metadata,
) )
if not envs.VLLM_REJECT_SAMPLE_OPT:
spec_token_ids = spec_result
else:
spec_token_ids, _ = spec_result
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group(): if has_kv_transfer_group():
...@@ -1795,9 +1791,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1795,9 +1791,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if not envs.VLLM_REJECT_SAMPLE_OPT: if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_token_ids = draft_result draft_token_ids = draft_result
spec_token_ids = draft_token_ids.tolist() else:
return spec_token_ids
draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys()) draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
draft_token_ids, draft_probs = draft_result draft_token_ids, draft_probs = draft_result
if self.draft_probs is None: if self.draft_probs is None:
...@@ -1806,7 +1800,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -1806,7 +1800,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
else: else:
self.draft_probs.update(draft_probs, draft_req_ids) self.draft_probs.update(draft_probs, draft_req_ids)
spec_token_ids = draft_token_ids.tolist() spec_token_ids = draft_token_ids.tolist()
return spec_token_ids, draft_probs
return spec_token_ids return spec_token_ids
...@@ -3442,7 +3435,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3442,7 +3435,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
first_neg_one_indices = torch.argmax(mask_int, dim=1) first_neg_one_indices = torch.argmax(mask_int, dim=1)
num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1 num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
spec_result = self.zero_propose_draft_token_ids( spec_token_ids = self.zero_propose_draft_token_ids(
scheduler_output, scheduler_output,
num_accepted_tokens_tensor, num_accepted_tokens_tensor,
sampled_token_ids, sampled_token_ids,
...@@ -3453,10 +3446,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3453,10 +3446,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
spec_decode_metadata, spec_decode_metadata,
attn_metadata, attn_metadata,
) )
if not envs.VLLM_REJECT_SAMPLE_OPT:
spec_token_ids = spec_result
else:
spec_token_ids, _ = spec_result
if max_gen_len == 1: if max_gen_len == 1:
# No spec decode tokens. # No spec decode tokens.
...@@ -3618,9 +3607,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3618,9 +3607,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# self.last_draft_event.record() # self.last_draft_event.record()
if not envs.VLLM_REJECT_SAMPLE_OPT: if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_token_ids = draft_result draft_token_ids = draft_result
spec_token_ids = draft_token_ids.tolist() else:
return spec_token_ids
draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys()) draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
draft_token_ids, draft_probs = draft_result draft_token_ids, draft_probs = draft_result
if self.draft_probs is None: if self.draft_probs is None:
...@@ -3628,8 +3615,9 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3628,8 +3615,9 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
draft_probs, draft_req_ids) draft_probs, draft_req_ids)
else: else:
self.draft_probs.update(draft_probs, draft_req_ids) self.draft_probs.update(draft_probs, draft_req_ids)
spec_token_ids = draft_token_ids.tolist() spec_token_ids = draft_token_ids.tolist()
return spec_token_ids, draft_probs
return spec_token_ids return spec_token_ids
#TODO:稳定后使用GPUModelRunnerMTP替换GPUModelRunner #TODO:稳定后使用GPUModelRunnerMTP替换GPUModelRunner
if envs.VLLM_USE_ZERO_MTP: if envs.VLLM_USE_ZERO_MTP:
......
...@@ -22,6 +22,7 @@ from vllm.zero_overhead.v1.eagle import V1ZeroEagleProposer ...@@ -22,6 +22,7 @@ from vllm.zero_overhead.v1.eagle import V1ZeroEagleProposer
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from vllm.v1.spec_decode.utils import DraftProbs
class V1ZeroModelRunner(GPUModelRunner): class V1ZeroModelRunner(GPUModelRunner):
...@@ -228,8 +229,12 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -228,8 +229,12 @@ class V1ZeroModelRunner(GPUModelRunner):
req_idx = self.input_batch.req_id_to_index[req_id] req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids) num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_ids = None
if envs.VLLM_REJECT_SAMPLE_OPT:
spec_decode_ids = scheduler_output.scheduled_spec_decode_tokens.keys()
spec_decode_metadata = self._calc_spec_decode_metadata( spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens) num_draft_tokens, cu_num_tokens, spec_decode_ids)
logits_indices = spec_decode_metadata.logits_indices logits_indices = spec_decode_metadata.logits_indices
# Hot-Swap lora model # Hot-Swap lora model
...@@ -370,7 +375,7 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -370,7 +375,7 @@ class V1ZeroModelRunner(GPUModelRunner):
target_slot_mapping = eagle_attn_metadata.slot_mapping[ target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices] token_indices]
self.drafter.spec_scheduler_max_num_tokens = spec_scheduler_max_num_tokens self.drafter.spec_scheduler_max_num_tokens = spec_scheduler_max_num_tokens
draft_token_ids = self.drafter.propose( draft_result = self.drafter.propose(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
target_positions=target_positions, target_positions=target_positions,
target_hidden_states=target_hidden_states, target_hidden_states=target_hidden_states,
...@@ -381,6 +386,19 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -381,6 +386,19 @@ class V1ZeroModelRunner(GPUModelRunner):
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
decoding=spec_decode_metadata is not None, decoding=spec_decode_metadata is not None,
) )
if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_token_ids = draft_result
else:
draft_token_ids, draft_probs = draft_result
draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, draft_req_ids)
else:
self.draft_probs.update(draft_probs, draft_req_ids)
spec_token_ids = np.ones(draft_token_ids.shape, dtype=int).tolist() spec_token_ids = np.ones(draft_token_ids.shape, dtype=int).tolist()
self.last_draft_token_ids = draft_token_ids self.last_draft_token_ids = draft_token_ids
self.last_draft_host_tokens = draft_token_ids.to('cpu', non_blocking=True) self.last_draft_host_tokens = draft_token_ids.to('cpu', non_blocking=True)
...@@ -555,6 +573,7 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -555,6 +573,7 @@ class V1ZeroModelRunner(GPUModelRunner):
# logits tensor. This means any in-place operations on bonus_logits # logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor. # won't affect the original logits tensor.
assert logits is not None assert logits is not None
if not envs.VLLM_REJECT_SAMPLE_OPT:
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler( sampler_output = self.sampler(
logits=bonus_logits, logits=bonus_logits,
...@@ -574,6 +593,27 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -574,6 +593,27 @@ class V1ZeroModelRunner(GPUModelRunner):
sampling_metadata, sampling_metadata,
) )
sampler_output.sampled_token_ids = output_token_ids sampler_output.sampled_token_ids = output_token_ids
else:
# sampling_metadata.all_greedy = True
# sampling_metadata.all_random = False
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
target_token_ids = sampler_output.sampled_token_ids[spec_decode_metadata.target_logits_indices]
target_logits = logits[spec_decode_metadata.target_logits_indices]
bonus_token_ids = sampler_output.sampled_token_ids[spec_decode_metadata.bonus_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
self.draft_probs.get_probs(spec_decode_metadata.spec_decode_ids),
target_logits,
target_token_ids,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
num_nans_in_logits = {} num_nans_in_logits = {}
if envs.VLLM_COMPUTE_NANS_IN_LOGITS: if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment