Commit 740e060f authored by 王敏's avatar 王敏
Browse files

[feat]零消耗适配宽松mtp

parent 4c4cfb18
......@@ -1645,7 +1645,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Speculative decoding is not enabled.
spec_token_ids = None
else:
spec_result = self.propose_draft_token_ids(
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
sampling_metadata,
......@@ -1655,10 +1655,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata,
attn_metadata,
)
if not envs.VLLM_REJECT_SAMPLE_OPT:
spec_token_ids = spec_result
else:
spec_token_ids, _ = spec_result
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
......@@ -1795,9 +1791,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_token_ids = draft_result
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
else:
draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
draft_token_ids, draft_probs = draft_result
if self.draft_probs is None:
......@@ -1806,7 +1800,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
else:
self.draft_probs.update(draft_probs, draft_req_ids)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids, draft_probs
return spec_token_ids
......@@ -3442,7 +3435,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
first_neg_one_indices = torch.argmax(mask_int, dim=1)
num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
spec_result = self.zero_propose_draft_token_ids(
spec_token_ids = self.zero_propose_draft_token_ids(
scheduler_output,
num_accepted_tokens_tensor,
sampled_token_ids,
......@@ -3453,10 +3446,6 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
spec_decode_metadata,
attn_metadata,
)
if not envs.VLLM_REJECT_SAMPLE_OPT:
spec_token_ids = spec_result
else:
spec_token_ids, _ = spec_result
if max_gen_len == 1:
# No spec decode tokens.
......@@ -3618,9 +3607,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# self.last_draft_event.record()
if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_token_ids = draft_result
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
else:
draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
draft_token_ids, draft_probs = draft_result
if self.draft_probs is None:
......@@ -3628,8 +3615,9 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
draft_probs, draft_req_ids)
else:
self.draft_probs.update(draft_probs, draft_req_ids)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids, draft_probs
return spec_token_ids
#TODO:稳定后使用GPUModelRunnerMTP替换GPUModelRunner
if envs.VLLM_USE_ZERO_MTP:
......
......@@ -22,6 +22,7 @@ from vllm.zero_overhead.v1.eagle import V1ZeroEagleProposer
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from vllm.v1.spec_decode.utils import DraftProbs
class V1ZeroModelRunner(GPUModelRunner):
......@@ -228,8 +229,12 @@ class V1ZeroModelRunner(GPUModelRunner):
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_ids = None
if envs.VLLM_REJECT_SAMPLE_OPT:
spec_decode_ids = scheduler_output.scheduled_spec_decode_tokens.keys()
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
num_draft_tokens, cu_num_tokens, spec_decode_ids)
logits_indices = spec_decode_metadata.logits_indices
# Hot-Swap lora model
......@@ -370,7 +375,7 @@ class V1ZeroModelRunner(GPUModelRunner):
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]
self.drafter.spec_scheduler_max_num_tokens = spec_scheduler_max_num_tokens
draft_token_ids = self.drafter.propose(
draft_result = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
......@@ -381,6 +386,19 @@ class V1ZeroModelRunner(GPUModelRunner):
sampling_metadata=sampling_metadata,
decoding=spec_decode_metadata is not None,
)
if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_token_ids = draft_result
else:
draft_token_ids, draft_probs = draft_result
draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, draft_req_ids)
else:
self.draft_probs.update(draft_probs, draft_req_ids)
spec_token_ids = np.ones(draft_token_ids.shape, dtype=int).tolist()
self.last_draft_token_ids = draft_token_ids
self.last_draft_host_tokens = draft_token_ids.to('cpu', non_blocking=True)
......@@ -555,6 +573,7 @@ class V1ZeroModelRunner(GPUModelRunner):
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
if not envs.VLLM_REJECT_SAMPLE_OPT:
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
......@@ -574,6 +593,27 @@ class V1ZeroModelRunner(GPUModelRunner):
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
else:
# sampling_metadata.all_greedy = True
# sampling_metadata.all_random = False
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
target_token_ids = sampler_output.sampled_token_ids[spec_decode_metadata.target_logits_indices]
target_logits = logits[spec_decode_metadata.target_logits_indices]
bonus_token_ids = sampler_output.sampled_token_ids[spec_decode_metadata.bonus_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
self.draft_probs.get_probs(spec_decode_metadata.spec_decode_ids),
target_logits,
target_token_ids,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
num_nans_in_logits = {}
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment