Commit 8a413453 authored by jujl1's avatar jujl1
Browse files

feat: 兼容MTP零消耗和主模型+MTP零消耗(VLLM_ZERO_OVERHEAD_ENHANCE=1)开启

parent 5208b291
......@@ -200,6 +200,7 @@ if TYPE_CHECKING:
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1298,6 +1299,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")),
"VLLM_ZERO_OVERHEAD_ENHANCE":
lambda: (os.getenv('VLLM_ZERO_OVERHEAD_ENHANCE', '0').lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -8,7 +8,7 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm import envs
requsets_valid_token_len = {}
def check_stop(request: Request,
......@@ -83,10 +83,22 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
request._all_token_ids[fix_offset] = generated_token_ids
requsets_valid_token_len[req_id] += 1
generated_token_ids = [generated_token_ids]
else:
elif envs.VLLM_ZERO_OVERHEAD_ENHANCE:
requsets_valid_token_len[req_id] += len(generated_token_ids)
request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
if num_tokens_scheduled == 0:
request.num_computed_tokens = request.num_tokens - 1
else:
valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
if valid_output_end == 0:
request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids
else:
request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids
request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids
requsets_valid_token_len[req_id] += len(generated_token_ids)
stopped = False
......@@ -189,8 +201,9 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else[]
if request.num_computed_tokens == request.num_prompt_tokens:
req_index] if sampled_token_ids else []
if (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
request.num_computed_tokens == request.num_prompt_tokens):
generated_token_ids = generated_token_ids[:1]
scheduled_spec_token_ids = (
......@@ -203,8 +216,10 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_new = len(generated_token_ids)
if (model_runner_output.fix_req_ids and req_id in model_runner_output.fix_req_ids
and request.num_computed_tokens > request.num_prompt_tokens + num_new):
if (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
model_runner_output.fix_req_ids and
req_id in model_runner_output.fix_req_ids and
request.num_computed_tokens > request.num_prompt_tokens + num_new):
req_idx = model_runner_output.fix_req_ids.index(req_id)
num_new = len(model_runner_output.fix_sampled_token_ids[req_idx])
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_new)
......@@ -213,7 +228,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
spec_decoding_stats = scheduler.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=(num_new - 1) if generated_token_ids else 0 )
num_accepted_tokens=num_new - 1)
# NOTE(woosuk): This has to be executed after updating
# `request.num_computed_tokens`.
......
......@@ -84,7 +84,8 @@ class V1ZeroModelRunner(GPUModelRunner):
num_scheduled_tokens)
if self.speculative_config and self.last_sampler_host_tokens != None:
if (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
self.speculative_config and self.last_sampler_host_tokens != None):
self.fix_req_ids = self.last_sampled_req_ids
self.last_sampler_event.synchronize() # 等上一轮主模型结束
num_gen_tokens = self.last_sampler_host_tokens.shape[-1]
......@@ -106,8 +107,8 @@ class V1ZeroModelRunner(GPUModelRunner):
# # 更新token统计数据
self.input_batch.num_tokens_no_spec[new_req_idx] = new_end_idx
self.input_batch.num_tokens[new_req_idx] = new_end_idx
self.input_batch.token_ids_cpu[new_req_idx, start_idx:new_end_idx] = self.fix_sampled_token_ids[
req_idx]
self.input_batch.token_ids_cpu[new_req_idx, start_idx:new_end_idx] = (
self.fix_sampled_token_ids)[req_idx]
self.input_batch.num_computed_tokens_cpu[new_req_idx] -= (end_idx - new_end_idx)
if req_id in self.requests:
req_state = self.requests[req_id]
......@@ -299,13 +300,15 @@ class V1ZeroModelRunner(GPUModelRunner):
True)
last_draft_token_ids = self.last_draft_token_ids.flatten().to(torch.int)
input_ids[input_ids_indices_tensor] = last_draft_token_ids[update_req_indices_tensor]
update_req_indices = []
input_ids_indices = []
token_idx = 0
if self.last_sampled_token_ids is not None:
sampled_tokens_num = 1 if self.speculative_config else self.last_sampled_token_ids.shape[1]
for req_id in req_ids:
if req_id in self.last_sampled_req_ids:
req_idx = self.last_sampled_req_ids.index(req_id)
req_idx = self.last_sampled_req_ids.index(req_id) * sampled_tokens_num
update_req_indices.append(req_idx)
input_ids_indices.append(token_idx)
token_idx += scheduler_output.num_scheduled_tokens[req_id]
......@@ -316,12 +319,14 @@ class V1ZeroModelRunner(GPUModelRunner):
input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32,
self.device,
True)
if self.speculative_config:
if envs.VLLM_ZERO_OVERHEAD_ENHANCE and self.speculative_config:
fused_update_input_ids_impl(self.last_sampled_token_ids,input_ids,
update_req_indices_tensor,input_ids_indices_tensor)
else:
last_sampled_token_ids = self.last_sampled_token_ids.flatten()
input_ids[input_ids_indices_tensor] =last_sampled_token_ids[update_req_indices_tensor]
for i in range(sampled_tokens_num):
input_ids[input_ids_indices_tensor + i] = (
last_sampled_token_ids)[update_req_indices_tensor + i]
def propose_draft_token_ids(
self,
......@@ -698,27 +703,29 @@ class V1ZeroModelRunner(GPUModelRunner):
is_output_valid = False
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
if not self.speculative_config:
self.fix_req_ids = self.last_sampled_req_ids
if self.last_sampler_host_tokens is not None:
self.last_sampler_event.synchronize()
self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
if start_idx == -1:
self.fix_sampled_token_ids[req_idx].clear()
continue
req_id = self.fix_req_ids[req_idx]
if req_id in self.input_batch.req_ids:
new_req_idx = self.input_batch.req_ids.index(req_id)
self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx]
for req_idx, req_id in enumerate(self.fix_req_ids):
if req_id in self.requests:
req_state = self.requests[req_id]
token_idx = self.last_sampled_token_lens[req_idx]
if token_idx == -1:
continue
fix_len = len(self.fix_sampled_token_ids[req_idx])
req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
over_head_enhance = envs.VLLM_ZERO_OVERHEAD_ENHANCE and self.speculative_config
if over_head_enhance:
# if not self.speculative_config:
# self.fix_req_ids = self.last_sampled_req_ids
# if self.last_sampler_host_tokens is not None:
# self.last_sampler_event.synchronize()
# self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
# for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
# if start_idx == -1:
# self.fix_sampled_token_ids[req_idx].clear()
# continue
# req_id = self.fix_req_ids[req_idx]
# if req_id in self.input_batch.req_ids:
# new_req_idx = self.input_batch.req_ids.index(req_id)
# self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx]
# for req_idx, req_id in enumerate(self.fix_req_ids):
# if req_id in self.requests:
# req_state = self.requests[req_id]
# token_idx = self.last_sampled_token_lens[req_idx]
# if token_idx == -1:
# continue
# fix_len = len(self.fix_sampled_token_ids[req_idx])
# req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
self.last_sampler_host_tokens = None
self.last_sampled_token_ids = None
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
......@@ -726,15 +733,14 @@ class V1ZeroModelRunner(GPUModelRunner):
self.last_sampled_token_ids = sampled_token_ids
valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
fix_draft_req_ids = None
else:
if not over_head_enhance:
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
self.spec_sampler_event.record()
if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist()
......@@ -755,6 +761,51 @@ class V1ZeroModelRunner(GPUModelRunner):
attn_metadata,
)
if not over_head_enhance:
if self.speculative_config:
self.spec_sampler_event.synchronize()
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = sampled_token_ids_cpu.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids_cpu,
self.input_batch.vocab_size,
)
self.last_sampler_host_tokens = None
self.last_sampled_token_ids = None
is_output_valid = True
else:
# No spec decode tokens.
self.fix_req_ids = self.last_sampled_req_ids
if self.last_sampler_host_tokens != None:
self.last_sampler_event.synchronize()
self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
if start_idx == -1:
continue
req_id = self.fix_req_ids[req_idx]
if req_id in self.input_batch.req_ids:
new_req_idx = self.input_batch.req_ids.index(req_id)
self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx]
for req_idx, req_id in enumerate(self.fix_req_ids):
if req_id in self.requests:
req_state = self.requests[req_id]
token_idx = self.last_sampled_token_lens[req_idx]
if token_idx == -1:
continue
fix_len = len(self.fix_sampled_token_ids[req_idx])
req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
self.last_sampler_event.record()
self.last_sampled_token_ids = sampled_token_ids
valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
......@@ -765,12 +816,13 @@ class V1ZeroModelRunner(GPUModelRunner):
self.last_sampled_token_lens = []
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
req_id = self.input_batch.req_ids[req_idx]
cache_output_len = -1
self.last_sampled_req_ids.append(req_id)
cache_output_len = -1
if not sampled_ids:
self.last_sampled_token_lens.append(-1)
self.token_ids_cpu_fix_record.append([req_idx, -1, -1])
continue
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids)
assert end_idx <= self.max_model_len, (
......@@ -783,7 +835,7 @@ class V1ZeroModelRunner(GPUModelRunner):
self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx])
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
if not self.speculative_config and req_id in self.requests:
if not over_head_enhance and req_id in self.requests:
req_state = self.requests[req_id]
cache_output_len = len(req_state.output_token_ids)
req_state.output_token_ids.extend(sampled_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment