Commit 072e3895 authored by gaoqiong's avatar gaoqiong
Browse files

Merge branch 'v0.9.2-dev-main+mtp-zero' into 'v0.9.2-dev'

V0.9.2 dev main+mtp zero

See merge request dcutoolkit/deeplearing/vllm!325
parents cd42bf87 76e22965
......@@ -200,6 +200,7 @@ if TYPE_CHECKING:
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1298,6 +1299,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")),
"VLLM_ZERO_OVERHEAD_ENHANCE":
lambda: (os.getenv('VLLM_ZERO_OVERHEAD_ENHANCE', '0').lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -4,6 +4,8 @@ from enum import Enum
import os
import torch
import vllm.envs as envs
import triton
import triton.language as tl
zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1'
......@@ -69,3 +71,60 @@ def zero_overhead_stream(target_device):
if target_device not in alloc_stream.keys():
alloc_stream[target_device] = torch.cuda.Stream(device=target_device)
return alloc_stream[target_device]
@triton.jit
def fused_last_valid_scatter_kernel(
last_ids_ptr, # [B, T]
input_ids_ptr, # [N]
update_req_ptr, # [U]
input_pos_ptr, # [U]
stride0,
stride1,
T,
BLOCK_T: tl.constexpr,
):
pid = tl.program_id(0)
# indices
req_idx = tl.load(update_req_ptr + pid)
input_pos = tl.load(input_pos_ptr + pid)
# load row
offs = tl.arange(0, BLOCK_T)
mask = offs < T
row_ptr = last_ids_ptr + req_idx * stride0 + offs * stride1
vals = tl.load(row_ptr, mask=mask, other=-1)
idx = tl.where(vals != -1, offs, -1)
last_idx = tl.max(idx, axis=0)
# load last token
last_val = tl.load(
last_ids_ptr + req_idx * stride0 + last_idx * stride1,
mask=last_idx >= 0,
other=0,
)
# scatter
tl.store(input_ids_ptr + input_pos, last_val)
def fused_update_input_ids_impl(
last_sampled_token_ids,
input_ids,
update_req_indices,
input_ids_indices,
):
B, T = last_sampled_token_ids.shape
U = update_req_indices.numel()
BLOCK_T = 1024
assert T <= BLOCK_T
grid = (U,)
fused_last_valid_scatter_kernel[grid](
last_sampled_token_ids,
input_ids,
update_req_indices,
input_ids_indices,
last_sampled_token_ids.stride(0),
last_sampled_token_ids.stride(1),
T,
BLOCK_T=BLOCK_T,
)
\ No newline at end of file
import torch
from collections import defaultdict
from typing import Optional
......@@ -9,18 +8,19 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm import envs
requsets_valid_token_len = {}
def check_stop(request: Request,
max_model_len: int,
pooler_output: Optional[torch.Tensor] = None,
use_valid_token_len:bool = False) -> bool:
use_valid_token_len: bool = False,
last_token_offset: Optional[int] = 0) -> bool:
if use_valid_token_len:
if request.request_id not in requsets_valid_token_len:
requsets_valid_token_len[request.request_id] = 0
return False
valid_output_len = requsets_valid_token_len[request.request_id]
valid_output_len = requsets_valid_token_len[request.request_id] - last_token_offset
else:
valid_output_len = request.num_output_tokens
valid_num_tokens = request.num_prompt_tokens + valid_output_len
......@@ -72,6 +72,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
continue
request = scheduler.requests[req_id]
generated_token_ids = model_runner_output.fix_sampled_token_ids[req_idx]
if not generated_token_ids:
continue
if req_id not in requsets_valid_token_len:
requsets_valid_token_len[req_id] = 0
valid_output_len = requsets_valid_token_len[req_id]
......@@ -81,6 +83,13 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
request._all_token_ids[fix_offset] = generated_token_ids
requsets_valid_token_len[req_id] += 1
generated_token_ids = [generated_token_ids]
elif envs.VLLM_ZERO_OVERHEAD_ENHANCE:
requsets_valid_token_len[req_id] += len(generated_token_ids)
request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
if num_tokens_scheduled == 0:
request.num_computed_tokens = request.num_tokens - 1
else:
valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
if valid_output_end == 0:
......@@ -100,7 +109,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
for num_new, output_token_id in enumerate(new_token_ids, 1):
stopped = check_stop(request, scheduler.max_model_len, True)
stopped = check_stop(request, scheduler.max_model_len, use_valid_token_len=True,
last_token_offset=len(new_token_ids) - num_new)
if stopped:
kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
......@@ -110,7 +120,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if pooler_outputs:
pooler_output = pooler_outputs[req_idx]
stopped = check_stop(request, scheduler.max_model_len,
pooler_output, True)
pooler_output=pooler_output, use_valid_token_len=True)
if stopped:
kv_transfer_params = scheduler._free_request(request)
......@@ -192,6 +202,9 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else []
if (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
request.num_computed_tokens == request.num_prompt_tokens):
generated_token_ids = generated_token_ids[:1]
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
......@@ -202,13 +215,20 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# num_computed_tokens is decreased by the number of rejected
# tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
num_new = len(generated_token_ids)
if (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
model_runner_output.fix_req_ids and
req_id in model_runner_output.fix_req_ids and
request.num_computed_tokens > request.num_prompt_tokens + num_new):
req_idx = model_runner_output.fix_req_ids.index(req_id)
num_new = len(model_runner_output.fix_sampled_token_ids[req_idx])
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_new)
request.num_computed_tokens -= num_tokens_rejected
spec_decoding_stats = scheduler.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=len(generated_token_ids) - 1)
num_accepted_tokens=num_new - 1)
# NOTE(woosuk): This has to be executed after updating
# `request.num_computed_tokens`.
......@@ -231,7 +251,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if model_runner_output.is_output_valid:
stopped = check_stop(request, scheduler.max_model_len,
False)
use_valid_token_len=False)
if stopped:
kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
......@@ -243,7 +263,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, scheduler.max_model_len,
pooler_output,
False)
use_valid_token_len=False)
if stopped:
kv_transfer_params = scheduler._free_request(request)
......
from typing import Any, Optional, Union
import torch
import numpy as np
......@@ -23,7 +22,7 @@ from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from vllm.v1.spec_decode.utils import DraftProbs
from vllm.zero_overhead.utils import fused_update_input_ids_impl
class V1ZeroModelRunner(GPUModelRunner):
def __init__(self, vllm_config, device):
......@@ -39,6 +38,9 @@ class V1ZeroModelRunner(GPUModelRunner):
self.last_draft_event = torch.cuda.Event(enable_timing=False)
self.spec_sampler_event = torch.cuda.Event(enable_timing=False)
self.spec_scheduler_max_num_tokens = 0
self.fix_req_ids = None
self.fix_sampled_token_ids = None
if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer):
self.drafter = V1ZeroEagleProposer(self.vllm_config, self.device,
self)
......@@ -81,6 +83,37 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
if (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
self.speculative_config and self.last_sampler_host_tokens != None):
self.fix_req_ids = self.last_sampled_req_ids
self.last_sampler_event.synchronize() # 等上一轮主模型结束
num_gen_tokens = self.last_sampler_host_tokens.shape[-1]
if num_gen_tokens == 1:
self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
else:
self.fix_sampled_token_ids = self.rejection_sampler.parse_output(
self.last_sampler_host_tokens,
self.input_batch.vocab_size,
)
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
if start_idx == -1:
continue
num_accepted_tokens = len(self.fix_sampled_token_ids[req_idx])
req_id = self.fix_req_ids[req_idx]
if req_id in self.input_batch.req_ids:
new_req_idx = self.input_batch.req_ids.index(req_id)
new_end_idx = start_idx + num_accepted_tokens
# # 更新token统计数据
self.input_batch.num_tokens_no_spec[new_req_idx] = new_end_idx
self.input_batch.num_tokens[new_req_idx] = new_end_idx
self.input_batch.token_ids_cpu[new_req_idx, start_idx:new_end_idx] = (
self.fix_sampled_token_ids)[req_idx]
self.input_batch.num_computed_tokens_cpu[new_req_idx] -= (end_idx - new_end_idx)
if req_id in self.requests:
req_state = self.requests[req_id]
req_state.output_token_ids.extend(self.fix_sampled_token_ids[req_idx])
# Get positions.
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
......@@ -272,7 +305,7 @@ class V1ZeroModelRunner(GPUModelRunner):
input_ids_indices = []
token_idx = 0
if self.last_sampled_token_ids is not None:
sampled_tokens_num = self.last_sampled_token_ids.shape[1]
sampled_tokens_num = 1 if self.speculative_config else self.last_sampled_token_ids.shape[1]
for req_id in req_ids:
if req_id in self.last_sampled_req_ids:
req_idx = self.last_sampled_req_ids.index(req_id) * sampled_tokens_num
......@@ -286,9 +319,14 @@ class V1ZeroModelRunner(GPUModelRunner):
input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32,
self.device,
True)
if envs.VLLM_ZERO_OVERHEAD_ENHANCE and self.speculative_config:
fused_update_input_ids_impl(self.last_sampled_token_ids,input_ids,
update_req_indices_tensor,input_ids_indices_tensor)
else:
last_sampled_token_ids = self.last_sampled_token_ids.flatten()
for i in range(sampled_tokens_num):
input_ids[input_ids_indices_tensor + i] = last_sampled_token_ids[update_req_indices_tensor + i]
input_ids[input_ids_indices_tensor + i] = (
last_sampled_token_ids)[update_req_indices_tensor + i]
def propose_draft_token_ids(
self,
......@@ -660,19 +698,48 @@ class V1ZeroModelRunner(GPUModelRunner):
scheduler_output,
)
fix_req_ids = None
fix_sampled_token_ids = None
fix_draft_token_ids = None
fix_draft_req_ids = self.last_sampled_req_ids
is_output_valid = False
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
over_head_enhance = (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
self.speculative_config is not None)
if over_head_enhance:
# if not self.speculative_config:
# self.fix_req_ids = self.last_sampled_req_ids
# if self.last_sampler_host_tokens is not None:
# self.last_sampler_event.synchronize()
# self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
# for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
# if start_idx == -1:
# self.fix_sampled_token_ids[req_idx].clear()
# continue
# req_id = self.fix_req_ids[req_idx]
# if req_id in self.input_batch.req_ids:
# new_req_idx = self.input_batch.req_ids.index(req_id)
# self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx]
# for req_idx, req_id in enumerate(self.fix_req_ids):
# if req_id in self.requests:
# req_state = self.requests[req_id]
# token_idx = self.last_sampled_token_lens[req_idx]
# if token_idx == -1:
# continue
# fix_len = len(self.fix_sampled_token_ids[req_idx])
# req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
self.last_sampler_host_tokens = None
self.last_sampled_token_ids = None
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
self.last_sampler_event.record()
self.last_sampled_token_ids = sampled_token_ids
valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
fix_draft_req_ids = None
else:
if not over_head_enhance:
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
self.spec_sampler_event.record()
if self.last_draft_host_tokens is not None:
......@@ -695,8 +762,10 @@ class V1ZeroModelRunner(GPUModelRunner):
attn_metadata,
)
if not over_head_enhance:
if self.speculative_config:
self.spec_sampler_event.synchronize()
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = sampled_token_ids_cpu.tolist()
else:
......@@ -710,25 +779,25 @@ class V1ZeroModelRunner(GPUModelRunner):
is_output_valid = True
else:
# No spec decode tokens.
fix_req_ids = self.last_sampled_req_ids
self.fix_req_ids = self.last_sampled_req_ids
if self.last_sampler_host_tokens != None:
self.last_sampler_event.synchronize()
fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
if start_idx == -1:
continue
req_id = fix_req_ids[req_idx]
req_id = self.fix_req_ids[req_idx]
if req_id in self.input_batch.req_ids:
new_req_idx = self.input_batch.req_ids.index(req_id)
self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = fix_sampled_token_ids[req_idx]
for req_idx, req_id in enumerate(fix_req_ids):
self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx]
for req_idx, req_id in enumerate(self.fix_req_ids):
if req_id in self.requests:
req_state = self.requests[req_id]
token_idx = self.last_sampled_token_lens[req_idx]
if token_idx == -1:
continue
fix_len = len(fix_sampled_token_ids[req_idx])
req_state.output_token_ids[token_idx:token_idx + fix_len] = fix_sampled_token_ids[req_idx]
fix_len = len(self.fix_sampled_token_ids[req_idx])
req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
self.last_sampler_event.record()
self.last_sampled_token_ids = sampled_token_ids
......@@ -767,13 +836,12 @@ class V1ZeroModelRunner(GPUModelRunner):
self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx])
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
if req_id in self.requests:
if not over_head_enhance and req_id in self.requests:
req_state = self.requests[req_id]
cache_output_len = len(req_state.output_token_ids)
req_state.output_token_ids.extend(sampled_ids)
self.last_sampled_token_lens.append(cache_output_len)
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
......@@ -791,10 +859,10 @@ class V1ZeroModelRunner(GPUModelRunner):
finished_sending=finished_sending,
finished_recving=finished_recving,
num_nans_in_logits=num_nans_in_logits,
fix_req_ids = fix_req_ids,
fix_sampled_token_ids = fix_sampled_token_ids,
fix_draft_tokens_ids = fix_draft_token_ids,
fix_draft_req_ids = fix_draft_req_ids,
fix_req_ids=self.fix_req_ids,
fix_sampled_token_ids=self.fix_sampled_token_ids,
fix_draft_tokens_ids=fix_draft_token_ids,
fix_draft_req_ids=fix_draft_req_ids,
is_output_valid=is_output_valid
)
return model_output
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment