Commit 8a413453 authored by jujl1's avatar jujl1
Browse files

feat: 兼容MTP零消耗和主模型+MTP零消耗(VLLM_ZERO_OVERHEAD_ENHANCE=1)开启

parent 5208b291
...@@ -200,6 +200,7 @@ if TYPE_CHECKING: ...@@ -200,6 +200,7 @@ if TYPE_CHECKING:
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1298,6 +1299,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1298,6 +1299,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM": "VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")), ("true", "1")),
"VLLM_ZERO_OVERHEAD_ENHANCE":
lambda: (os.getenv('VLLM_ZERO_OVERHEAD_ENHANCE', '0').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -8,7 +8,7 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs ...@@ -8,7 +8,7 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm import envs
requsets_valid_token_len = {} requsets_valid_token_len = {}
def check_stop(request: Request, def check_stop(request: Request,
...@@ -83,10 +83,22 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -83,10 +83,22 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
request._all_token_ids[fix_offset] = generated_token_ids request._all_token_ids[fix_offset] = generated_token_ids
requsets_valid_token_len[req_id] += 1 requsets_valid_token_len[req_id] += 1
generated_token_ids = [generated_token_ids] generated_token_ids = [generated_token_ids]
else: elif envs.VLLM_ZERO_OVERHEAD_ENHANCE:
requsets_valid_token_len[req_id] += len(generated_token_ids) requsets_valid_token_len[req_id] += len(generated_token_ids)
request._output_token_ids[fix_offset : ] = generated_token_ids request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids request._all_token_ids[fix_offset : ] = generated_token_ids
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
if num_tokens_scheduled == 0:
request.num_computed_tokens = request.num_tokens - 1
else:
valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
if valid_output_end == 0:
request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids
else:
request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids
request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids
requsets_valid_token_len[req_id] += len(generated_token_ids)
stopped = False stopped = False
...@@ -189,8 +201,9 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -189,8 +201,9 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
req_index = model_runner_output.req_id_to_index[req_id] req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[ generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else[] req_index] if sampled_token_ids else []
if request.num_computed_tokens == request.num_prompt_tokens: if (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
request.num_computed_tokens == request.num_prompt_tokens):
generated_token_ids = generated_token_ids[:1] generated_token_ids = generated_token_ids[:1]
scheduled_spec_token_ids = ( scheduled_spec_token_ids = (
...@@ -203,8 +216,10 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -203,8 +216,10 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# tokens, where is given by: # tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_new = len(generated_token_ids) num_new = len(generated_token_ids)
if (model_runner_output.fix_req_ids and req_id in model_runner_output.fix_req_ids if (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
and request.num_computed_tokens > request.num_prompt_tokens + num_new): model_runner_output.fix_req_ids and
req_id in model_runner_output.fix_req_ids and
request.num_computed_tokens > request.num_prompt_tokens + num_new):
req_idx = model_runner_output.fix_req_ids.index(req_id) req_idx = model_runner_output.fix_req_ids.index(req_id)
num_new = len(model_runner_output.fix_sampled_token_ids[req_idx]) num_new = len(model_runner_output.fix_sampled_token_ids[req_idx])
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_new) num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_new)
...@@ -213,7 +228,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -213,7 +228,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
spec_decoding_stats = scheduler.make_spec_decoding_stats( spec_decoding_stats = scheduler.make_spec_decoding_stats(
spec_decoding_stats, spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids), num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=(num_new - 1) if generated_token_ids else 0 ) num_accepted_tokens=num_new - 1)
# NOTE(woosuk): This has to be executed after updating # NOTE(woosuk): This has to be executed after updating
# `request.num_computed_tokens`. # `request.num_computed_tokens`.
......
...@@ -84,7 +84,8 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -84,7 +84,8 @@ class V1ZeroModelRunner(GPUModelRunner):
num_scheduled_tokens) num_scheduled_tokens)
if self.speculative_config and self.last_sampler_host_tokens != None: if (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
self.speculative_config and self.last_sampler_host_tokens != None):
self.fix_req_ids = self.last_sampled_req_ids self.fix_req_ids = self.last_sampled_req_ids
self.last_sampler_event.synchronize() # 等上一轮主模型结束 self.last_sampler_event.synchronize() # 等上一轮主模型结束
num_gen_tokens = self.last_sampler_host_tokens.shape[-1] num_gen_tokens = self.last_sampler_host_tokens.shape[-1]
...@@ -106,8 +107,8 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -106,8 +107,8 @@ class V1ZeroModelRunner(GPUModelRunner):
# # 更新token统计数据 # # 更新token统计数据
self.input_batch.num_tokens_no_spec[new_req_idx] = new_end_idx self.input_batch.num_tokens_no_spec[new_req_idx] = new_end_idx
self.input_batch.num_tokens[new_req_idx] = new_end_idx self.input_batch.num_tokens[new_req_idx] = new_end_idx
self.input_batch.token_ids_cpu[new_req_idx, start_idx:new_end_idx] = self.fix_sampled_token_ids[ self.input_batch.token_ids_cpu[new_req_idx, start_idx:new_end_idx] = (
req_idx] self.fix_sampled_token_ids)[req_idx]
self.input_batch.num_computed_tokens_cpu[new_req_idx] -= (end_idx - new_end_idx) self.input_batch.num_computed_tokens_cpu[new_req_idx] -= (end_idx - new_end_idx)
if req_id in self.requests: if req_id in self.requests:
req_state = self.requests[req_id] req_state = self.requests[req_id]
...@@ -299,13 +300,15 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -299,13 +300,15 @@ class V1ZeroModelRunner(GPUModelRunner):
True) True)
last_draft_token_ids = self.last_draft_token_ids.flatten().to(torch.int) last_draft_token_ids = self.last_draft_token_ids.flatten().to(torch.int)
input_ids[input_ids_indices_tensor] = last_draft_token_ids[update_req_indices_tensor] input_ids[input_ids_indices_tensor] = last_draft_token_ids[update_req_indices_tensor]
update_req_indices = [] update_req_indices = []
input_ids_indices = [] input_ids_indices = []
token_idx = 0 token_idx = 0
if self.last_sampled_token_ids is not None: if self.last_sampled_token_ids is not None:
sampled_tokens_num = 1 if self.speculative_config else self.last_sampled_token_ids.shape[1]
for req_id in req_ids: for req_id in req_ids:
if req_id in self.last_sampled_req_ids: if req_id in self.last_sampled_req_ids:
req_idx = self.last_sampled_req_ids.index(req_id) req_idx = self.last_sampled_req_ids.index(req_id) * sampled_tokens_num
update_req_indices.append(req_idx) update_req_indices.append(req_idx)
input_ids_indices.append(token_idx) input_ids_indices.append(token_idx)
token_idx += scheduler_output.num_scheduled_tokens[req_id] token_idx += scheduler_output.num_scheduled_tokens[req_id]
...@@ -316,12 +319,14 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -316,12 +319,14 @@ class V1ZeroModelRunner(GPUModelRunner):
input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32, input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32,
self.device, self.device,
True) True)
if self.speculative_config: if envs.VLLM_ZERO_OVERHEAD_ENHANCE and self.speculative_config:
fused_update_input_ids_impl(self.last_sampled_token_ids,input_ids, fused_update_input_ids_impl(self.last_sampled_token_ids,input_ids,
update_req_indices_tensor,input_ids_indices_tensor) update_req_indices_tensor,input_ids_indices_tensor)
else: else:
last_sampled_token_ids = self.last_sampled_token_ids.flatten() last_sampled_token_ids = self.last_sampled_token_ids.flatten()
input_ids[input_ids_indices_tensor] =last_sampled_token_ids[update_req_indices_tensor] for i in range(sampled_tokens_num):
input_ids[input_ids_indices_tensor + i] = (
last_sampled_token_ids)[update_req_indices_tensor + i]
def propose_draft_token_ids( def propose_draft_token_ids(
self, self,
...@@ -698,27 +703,29 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -698,27 +703,29 @@ class V1ZeroModelRunner(GPUModelRunner):
is_output_valid = False is_output_valid = False
# Get the valid generated tokens. # Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
if not self.speculative_config: over_head_enhance = envs.VLLM_ZERO_OVERHEAD_ENHANCE and self.speculative_config
self.fix_req_ids = self.last_sampled_req_ids if over_head_enhance:
if self.last_sampler_host_tokens is not None: # if not self.speculative_config:
self.last_sampler_event.synchronize() # self.fix_req_ids = self.last_sampled_req_ids
self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist() # if self.last_sampler_host_tokens is not None:
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record: # self.last_sampler_event.synchronize()
if start_idx == -1: # self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
self.fix_sampled_token_ids[req_idx].clear() # for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
continue # if start_idx == -1:
req_id = self.fix_req_ids[req_idx] # self.fix_sampled_token_ids[req_idx].clear()
if req_id in self.input_batch.req_ids: # continue
new_req_idx = self.input_batch.req_ids.index(req_id) # req_id = self.fix_req_ids[req_idx]
self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx] # if req_id in self.input_batch.req_ids:
for req_idx, req_id in enumerate(self.fix_req_ids): # new_req_idx = self.input_batch.req_ids.index(req_id)
if req_id in self.requests: # self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx]
req_state = self.requests[req_id] # for req_idx, req_id in enumerate(self.fix_req_ids):
token_idx = self.last_sampled_token_lens[req_idx] # if req_id in self.requests:
if token_idx == -1: # req_state = self.requests[req_id]
continue # token_idx = self.last_sampled_token_lens[req_idx]
fix_len = len(self.fix_sampled_token_ids[req_idx]) # if token_idx == -1:
req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx] # continue
# fix_len = len(self.fix_sampled_token_ids[req_idx])
# req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
self.last_sampler_host_tokens = None self.last_sampler_host_tokens = None
self.last_sampled_token_ids = None self.last_sampled_token_ids = None
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True) self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
...@@ -726,15 +733,14 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -726,15 +733,14 @@ class V1ZeroModelRunner(GPUModelRunner):
self.last_sampled_token_ids = sampled_token_ids self.last_sampled_token_ids = sampled_token_ids
valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist() valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
if not self.speculative_config: if not self.speculative_config:
# Speculative decoding is not enabled. # Speculative decoding is not enabled.
spec_token_ids = None spec_token_ids = None
fix_draft_req_ids = None fix_draft_req_ids = None
else: else:
if not over_head_enhance:
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
self.spec_sampler_event.record()
if self.last_draft_host_tokens is not None: if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize() self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist() fix_draft_token_ids = self.last_draft_host_tokens.tolist()
...@@ -755,6 +761,51 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -755,6 +761,51 @@ class V1ZeroModelRunner(GPUModelRunner):
attn_metadata, attn_metadata,
) )
if not over_head_enhance:
if self.speculative_config:
self.spec_sampler_event.synchronize()
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = sampled_token_ids_cpu.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids_cpu,
self.input_batch.vocab_size,
)
self.last_sampler_host_tokens = None
self.last_sampled_token_ids = None
is_output_valid = True
else:
# No spec decode tokens.
self.fix_req_ids = self.last_sampled_req_ids
if self.last_sampler_host_tokens != None:
self.last_sampler_event.synchronize()
self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
if start_idx == -1:
continue
req_id = self.fix_req_ids[req_idx]
if req_id in self.input_batch.req_ids:
new_req_idx = self.input_batch.req_ids.index(req_id)
self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx]
for req_idx, req_id in enumerate(self.fix_req_ids):
if req_id in self.requests:
req_state = self.requests[req_id]
token_idx = self.last_sampled_token_lens[req_idx]
if token_idx == -1:
continue
fix_len = len(self.fix_sampled_token_ids[req_idx])
req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
self.last_sampler_event.record()
self.last_sampled_token_ids = sampled_token_ids
valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
# Cache the sampled tokens in the model runner, so that the scheduler # Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back. # doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends # NOTE(woosuk): As an exception, when using PP, the scheduler sends
...@@ -765,12 +816,13 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -765,12 +816,13 @@ class V1ZeroModelRunner(GPUModelRunner):
self.last_sampled_token_lens = [] self.last_sampled_token_lens = []
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
req_id = self.input_batch.req_ids[req_idx] req_id = self.input_batch.req_ids[req_idx]
cache_output_len = -1
self.last_sampled_req_ids.append(req_id) self.last_sampled_req_ids.append(req_id)
cache_output_len = -1
if not sampled_ids: if not sampled_ids:
self.last_sampled_token_lens.append(-1) self.last_sampled_token_lens.append(-1)
self.token_ids_cpu_fix_record.append([req_idx, -1, -1]) self.token_ids_cpu_fix_record.append([req_idx, -1, -1])
continue continue
start_idx = self.input_batch.num_tokens_no_spec[req_idx] start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids) end_idx = start_idx + len(sampled_ids)
assert end_idx <= self.max_model_len, ( assert end_idx <= self.max_model_len, (
...@@ -783,7 +835,7 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -783,7 +835,7 @@ class V1ZeroModelRunner(GPUModelRunner):
self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx]) self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx])
self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx
if not self.speculative_config and req_id in self.requests: if not over_head_enhance and req_id in self.requests:
req_state = self.requests[req_id] req_state = self.requests[req_id]
cache_output_len = len(req_state.output_token_ids) cache_output_len = len(req_state.output_token_ids)
req_state.output_token_ids.extend(sampled_ids) req_state.output_token_ids.extend(sampled_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment