Commit 0936ee97 authored by jujl1's avatar jujl1
Browse files

feat: 主模型+mtp提前返回

parent cd42bf87
import torch import torch
from collections import defaultdict from collections import defaultdict
from typing import Optional from typing import Optional
...@@ -89,9 +88,10 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -89,9 +88,10 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
else: else:
request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids
request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids
requsets_valid_token_len[req_id] += len(generated_token_ids) requsets_valid_token_len[req_id] += len(generated_token_ids)
request._output_token_ids[:] = request._output_token_ids[:requsets_valid_token_len[req_id]]
request._all_token_ids[:] = request._all_token_ids[:request.num_prompt_tokens + requsets_valid_token_len[req_id]]
stopped = False stopped = False
new_logprobs = None new_logprobs = None
new_token_ids = generated_token_ids new_token_ids = generated_token_ids
...@@ -110,7 +110,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -110,7 +110,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if pooler_outputs: if pooler_outputs:
pooler_output = pooler_outputs[req_idx] pooler_output = pooler_outputs[req_idx]
stopped = check_stop(request, scheduler.max_model_len, stopped = check_stop(request, scheduler.max_model_len,
pooler_output, True) pooler_output=pooler_output, use_valid_token_len=True)
if stopped: if stopped:
kv_transfer_params = scheduler._free_request(request) kv_transfer_params = scheduler._free_request(request)
...@@ -191,7 +191,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -191,7 +191,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
req_index = model_runner_output.req_id_to_index[req_id] req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[ generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else [] req_index] if sampled_token_ids else[]
scheduled_spec_token_ids = ( scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id)) scheduler_output.scheduled_spec_decode_tokens.get(req_id))
...@@ -202,13 +203,18 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -202,13 +203,18 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# num_computed_tokens is decreased by the number of rejected # num_computed_tokens is decreased by the number of rejected
# tokens, where is given by: # tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_new = len(generated_token_ids)
len(generated_token_ids)) if (model_runner_output.fix_req_ids and req_id in model_runner_output.fix_req_ids
and request.num_computed_tokens <= request.num_prompt_tokens + num_new):
req_idx = model_runner_output.fix_req_ids.index(req_id)
num_new = len(model_runner_output.fix_sampled_token_ids[req_idx])
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_new)
request.num_computed_tokens -= num_tokens_rejected request.num_computed_tokens -= num_tokens_rejected
spec_decoding_stats = scheduler.make_spec_decoding_stats( spec_decoding_stats = scheduler.make_spec_decoding_stats(
spec_decoding_stats, spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids), num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=len(generated_token_ids) - 1) num_accepted_tokens=(num_new - 1) if generated_token_ids else 0 )
# NOTE(woosuk): This has to be executed after updating # NOTE(woosuk): This has to be executed after updating
# `request.num_computed_tokens`. # `request.num_computed_tokens`.
...@@ -231,7 +237,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -231,7 +237,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if model_runner_output.is_output_valid: if model_runner_output.is_output_valid:
stopped = check_stop(request, scheduler.max_model_len, stopped = check_stop(request, scheduler.max_model_len,
False) use_valid_token_len=False)
if stopped: if stopped:
kv_transfer_params = scheduler._free_request(request) kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed. del new_token_ids[num_new:] # Trim new tokens if needed.
...@@ -242,8 +248,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -242,8 +248,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if model_runner_output.is_output_valid: if model_runner_output.is_output_valid:
pooler_output = pooler_outputs[req_index] pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, scheduler.max_model_len, stopped = check_stop(request, scheduler.max_model_len,
pooler_output, pooler_output,
False) use_valid_token_len=False)
if stopped: if stopped:
kv_transfer_params = scheduler._free_request(request) kv_transfer_params = scheduler._free_request(request)
...@@ -350,10 +356,10 @@ def engine_core_step(core) -> tuple[dict[int, EngineCoreOutputs], bool]: ...@@ -350,10 +356,10 @@ def engine_core_step(core) -> tuple[dict[int, EngineCoreOutputs], bool]:
model_output = core.execute_model(scheduler_output) model_output = core.execute_model(scheduler_output)
if isinstance(model_output, ZeroV1ModelRunnerOutput): if isinstance(model_output, ZeroV1ModelRunnerOutput):
engine_core_outputs = zero_overhead_update_from_output(core.scheduler, engine_core_outputs = zero_overhead_update_from_output(core.scheduler,
scheduler_output, model_output) # type: ignore scheduler_output, model_output) # type: ignore
else: else:
engine_core_outputs = core.scheduler.update_from_output( engine_core_outputs = core.scheduler.update_from_output(
scheduler_output, model_output) # type: ignore scheduler_output, model_output) # type: ignore
return (engine_core_outputs, return (engine_core_outputs,
scheduler_output.total_num_scheduled_tokens > 0) scheduler_output.total_num_scheduled_tokens > 0)
\ No newline at end of file
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
import numpy as np import numpy as np
...@@ -39,6 +38,9 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -39,6 +38,9 @@ class V1ZeroModelRunner(GPUModelRunner):
self.last_draft_event = torch.cuda.Event(enable_timing=False) self.last_draft_event = torch.cuda.Event(enable_timing=False)
self.spec_sampler_event = torch.cuda.Event(enable_timing=False) self.spec_sampler_event = torch.cuda.Event(enable_timing=False)
self.spec_scheduler_max_num_tokens = 0 self.spec_scheduler_max_num_tokens = 0
self.fix_req_ids = None
self.fix_sampled_token_ids = None
if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer): if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer):
self.drafter = V1ZeroEagleProposer(self.vllm_config, self.device, self.drafter = V1ZeroEagleProposer(self.vllm_config, self.device,
self) self)
...@@ -81,6 +83,39 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -81,6 +83,39 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens, arange = self._get_cumsum_and_arange( cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens) num_scheduled_tokens)
self.fix_req_ids = self.last_sampled_req_ids
if self.last_sampler_host_tokens != None:
self.last_sampler_event.synchronize() # 等上一轮主模型结束
if self.speculative_config: # 处理上一轮mtp
num_gen_tokens = self.last_sampler_host_tokens.shape[-1]
if num_gen_tokens == 1:
self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
else:
# Includes spec decode tokens.
self.fix_sampled_token_ids = self.rejection_sampler.parse_output(
self.last_sampler_host_tokens,
self.input_batch.vocab_size,
)
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
if start_idx == -1:
self.fix_sampled_token_ids[req_idx].clear()
else:
num_accepted_tokens = len(self.fix_sampled_token_ids[req_idx])
req_id = self.fix_req_ids[req_idx]
if req_id in self.input_batch.req_ids:
new_req_idx = self.input_batch.req_ids.index(req_id)
new_end_idx = start_idx + num_accepted_tokens
# # 更新token统计数据
self.input_batch.num_tokens_no_spec[new_req_idx] = new_end_idx
self.input_batch.num_tokens[new_req_idx] = new_end_idx
self.input_batch.token_ids_cpu[new_req_idx, start_idx:new_end_idx] = self.fix_sampled_token_ids[
req_idx]
self.input_batch.num_computed_tokens_cpu[new_req_idx] -= (end_idx - new_end_idx)
if req_id in self.requests:
req_state = self.requests[req_id]
req_state.output_token_ids.extend(self.fix_sampled_token_ids[req_idx])
# Get positions. # Get positions.
positions_np = self.positions_np[:total_num_scheduled_tokens] positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices], np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
...@@ -267,15 +302,26 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -267,15 +302,26 @@ class V1ZeroModelRunner(GPUModelRunner):
True) True)
last_draft_token_ids = self.last_draft_token_ids.flatten().to(torch.int) last_draft_token_ids = self.last_draft_token_ids.flatten().to(torch.int)
input_ids[input_ids_indices_tensor] = last_draft_token_ids[update_req_indices_tensor] input_ids[input_ids_indices_tensor] = last_draft_token_ids[update_req_indices_tensor]
def find_last_valid_vectorized(tensor):
"""
向量化方法找到每行最后一个非-1元素
"""
mask = tensor != -1
reversed_mask = mask.flip(dims=[1]) # 沿着列方向反转
_, col_indices = torch.max(reversed_mask.int(), dim=1)
original_col_indices = tensor.size(1) - 1 - col_indices
result = tensor[torch.arange(tensor.size(0)), original_col_indices]
all_invalid = ~mask.any(dim=1)
result[all_invalid] = -1 # 或者设置为其他默认值
return result
update_req_indices = [] update_req_indices = []
input_ids_indices = [] input_ids_indices = []
token_idx = 0 token_idx = 0
if self.last_sampled_token_ids is not None: if self.last_sampled_token_ids is not None:
sampled_tokens_num = self.last_sampled_token_ids.shape[1]
for req_id in req_ids: for req_id in req_ids:
if req_id in self.last_sampled_req_ids: if req_id in self.last_sampled_req_ids:
req_idx = self.last_sampled_req_ids.index(req_id) * sampled_tokens_num req_idx = self.last_sampled_req_ids.index(req_id)
update_req_indices.append(req_idx) update_req_indices.append(req_idx)
input_ids_indices.append(token_idx) input_ids_indices.append(token_idx)
token_idx += scheduler_output.num_scheduled_tokens[req_id] token_idx += scheduler_output.num_scheduled_tokens[req_id]
...@@ -286,9 +332,9 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -286,9 +332,9 @@ class V1ZeroModelRunner(GPUModelRunner):
input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32, input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32,
self.device, self.device,
True) True)
last_sampled_token_ids = self.last_sampled_token_ids.flatten() last_sampled_token_ids = find_last_valid_vectorized(self.last_sampled_token_ids).flatten()
for i in range(sampled_tokens_num): input_ids[input_ids_indices_tensor] = last_sampled_token_ids[update_req_indices_tensor]
input_ids[input_ids_indices_tensor + i] = last_sampled_token_ids[update_req_indices_tensor + i]
def propose_draft_token_ids( def propose_draft_token_ids(
self, self,
...@@ -660,80 +706,19 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -660,80 +706,19 @@ class V1ZeroModelRunner(GPUModelRunner):
scheduler_output, scheduler_output,
) )
fix_req_ids = None
fix_sampled_token_ids = None
fix_draft_token_ids = None fix_draft_token_ids = None
fix_draft_req_ids = self.last_sampled_req_ids fix_draft_req_ids = self.last_sampled_req_ids
is_output_valid = False is_output_valid = False
# Get the valid generated tokens. # Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
fix_draft_req_ids = None
else:
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
self.spec_sampler_event.record()
if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist()
mask = (sampled_token_ids == -1) self.last_sampler_host_tokens = None
mask_int = mask.int() self.last_sampled_token_ids = None
first_neg_one_indices = torch.argmax(mask_int, dim=1) self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1 self.last_sampler_event.record()
spec_token_ids = self.propose_draft_token_ids( self.last_sampled_token_ids = sampled_token_ids
scheduler_output, valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
num_accepted_tokens_tensor,
sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
if self.speculative_config:
self.spec_sampler_event.synchronize()
if max_gen_len == 1:
valid_sampled_token_ids = sampled_token_ids_cpu.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids_cpu,
self.input_batch.vocab_size,
)
self.last_sampler_host_tokens = None
self.last_sampled_token_ids = None
is_output_valid = True
else:
# No spec decode tokens.
fix_req_ids = self.last_sampled_req_ids
if self.last_sampler_host_tokens != None:
self.last_sampler_event.synchronize()
fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
if start_idx == -1:
continue
req_id = fix_req_ids[req_idx]
if req_id in self.input_batch.req_ids:
new_req_idx = self.input_batch.req_ids.index(req_id)
self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = fix_sampled_token_ids[req_idx]
for req_idx, req_id in enumerate(fix_req_ids):
if req_id in self.requests:
req_state = self.requests[req_id]
token_idx = self.last_sampled_token_lens[req_idx]
if token_idx == -1:
continue
fix_len = len(fix_sampled_token_ids[req_idx])
req_state.output_token_ids[token_idx:token_idx + fix_len] = fix_sampled_token_ids[req_idx]
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
self.last_sampler_event.record()
self.last_sampled_token_ids = sampled_token_ids
valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
# Mask out the sampled tokens that should not be sampled. # Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices: for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear() valid_sampled_token_ids[i].clear()
...@@ -767,12 +752,32 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -767,12 +752,32 @@ class V1ZeroModelRunner(GPUModelRunner):
self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx]) self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx])
self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx
if req_id in self.requests:
req_state = self.requests[req_id] if not self.speculative_config:
cache_output_len = len(req_state.output_token_ids) # Speculative decoding is not enabled.
req_state.output_token_ids.extend(sampled_ids) spec_token_ids = None
self.last_sampled_token_lens.append(cache_output_len) fix_draft_req_ids = None
else:
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist()
mask = (sampled_token_ids == -1)
mask_int = mask.int()
first_neg_one_indices = torch.argmax(mask_int, dim=1)
num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
num_accepted_tokens_tensor,
sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group(): if has_kv_transfer_group():
...@@ -791,10 +796,10 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -791,10 +796,10 @@ class V1ZeroModelRunner(GPUModelRunner):
finished_sending=finished_sending, finished_sending=finished_sending,
finished_recving=finished_recving, finished_recving=finished_recving,
num_nans_in_logits=num_nans_in_logits, num_nans_in_logits=num_nans_in_logits,
fix_req_ids = fix_req_ids, fix_req_ids=self.fix_req_ids,
fix_sampled_token_ids = fix_sampled_token_ids, fix_sampled_token_ids=self.fix_sampled_token_ids,
fix_draft_tokens_ids = fix_draft_token_ids, fix_draft_tokens_ids=fix_draft_token_ids,
fix_draft_req_ids = fix_draft_req_ids, fix_draft_req_ids=fix_draft_req_ids,
is_output_valid=is_output_valid is_output_valid=is_output_valid
) )
return model_output return model_output
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment