Commit 072e3895 authored by gaoqiong
Browse files

Merge branch 'v0.9.2-dev-main+mtp-zero' into 'v0.9.2-dev'

V0.9.2 dev main+mtp zero

See merge request dcutoolkit/deeplearing/vllm!325
parents cd42bf87 76e22965
...@@ -200,6 +200,7 @@ if TYPE_CHECKING: ...@@ -200,6 +200,7 @@ if TYPE_CHECKING:
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1298,6 +1299,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1298,6 +1299,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM": "VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")), ("true", "1")),
"VLLM_ZERO_OVERHEAD_ENHANCE":
lambda: (os.getenv('VLLM_ZERO_OVERHEAD_ENHANCE', '0').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -4,6 +4,8 @@ from enum import Enum ...@@ -4,6 +4,8 @@ from enum import Enum
import os import os
import torch import torch
import vllm.envs as envs import vllm.envs as envs
import triton
import triton.language as tl
zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1' zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1'
...@@ -69,3 +71,60 @@ def zero_overhead_stream(target_device): ...@@ -69,3 +71,60 @@ def zero_overhead_stream(target_device):
if target_device not in alloc_stream.keys(): if target_device not in alloc_stream.keys():
alloc_stream[target_device] = torch.cuda.Stream(device=target_device) alloc_stream[target_device] = torch.cuda.Stream(device=target_device)
return alloc_stream[target_device] return alloc_stream[target_device]
@triton.jit
def fused_last_valid_scatter_kernel(
    last_ids_ptr,  # [B, T]
    input_ids_ptr,  # [N]
    update_req_ptr,  # [U]
    input_pos_ptr,  # [U]
    stride0,
    stride1,
    T,
    BLOCK_T: tl.constexpr,
):
    """Copy the last non-(-1) entry of one row of `last_ids` into `input_ids`.

    One program handles one update. It reads a row index from
    `update_req_ptr[pid]` and a destination offset from `input_pos_ptr[pid]`,
    scans the first `T` columns of that row, and stores the value found at
    the highest column whose entry differs from -1 to
    `input_ids_ptr[destination]`. When every scanned entry is -1, 0 is
    stored instead.

    `stride0` / `stride1` are the row / column strides applied to
    `last_ids_ptr`; `BLOCK_T` is the compile-time scan width and columns at
    or beyond `T` are masked out.
    """
    program = tl.program_id(0)

    # Per-update source row and destination slot.
    row = tl.load(update_req_ptr + program)
    dst = tl.load(input_pos_ptr + program)

    # Read the row; masked-out lanes look like the -1 sentinel.
    row_base = last_ids_ptr + row * stride0
    cols = tl.arange(0, BLOCK_T)
    in_row = cols < T
    tokens = tl.load(row_base + cols * stride1, mask=in_row, other=-1)

    # Highest column holding a non-sentinel value; -1 if there is none.
    candidate_cols = tl.where(tokens != -1, cols, -1)
    last_col = tl.max(candidate_cols, axis=0)

    # Fetch that entry (0 when the row had no valid entry) and scatter it.
    token = tl.load(
        row_base + last_col * stride1,
        mask=last_col >= 0,
        other=0,
    )
    tl.store(input_ids_ptr + dst, token)
def fused_update_input_ids_impl(
    last_sampled_token_ids,
    input_ids,
    update_req_indices,
    input_ids_indices,
):
    """Launch `fused_last_valid_scatter_kernel` once per update entry.

    For every position `i` in `update_req_indices`, the kernel takes row
    `update_req_indices[i]` of `last_sampled_token_ids`, picks the last
    entry that is not -1, and writes it to
    `input_ids[input_ids_indices[i]]` in place.

    Args:
        last_sampled_token_ids: 2-D tensor; unpacking its shape into two
            values raises `ValueError` for any other rank.
        input_ids: destination tensor, modified in place by the kernel.
        update_req_indices: row indices into `last_sampled_token_ids`, one
            per update.
        input_ids_indices: destination offsets into `input_ids`, one per
            update.

    Returns:
        None. Nothing is launched when `update_req_indices` is empty.
    """
    _, T = last_sampled_token_ids.shape
    U = update_req_indices.numel()
    if U == 0:
        # Nothing to scatter: skip the kernel launch entirely.
        return

    # Size the scan window to the actual row width instead of a fixed 1024,
    # so rows of any width are covered without relying on an `assert` that
    # `python -O` would strip. `tl.arange` needs a power-of-two extent; the
    # floor of 16 keeps tiny rows on a conventional block size.
    BLOCK_T = max(triton.next_power_of_2(T), 16)

    grid = (U,)
    fused_last_valid_scatter_kernel[grid](
        last_sampled_token_ids,
        input_ids,
        update_req_indices,
        input_ids_indices,
        last_sampled_token_ids.stride(0),
        last_sampled_token_ids.stride(1),
        T,
        BLOCK_T=BLOCK_T,
    )
\ No newline at end of file
import torch import torch
from collections import defaultdict from collections import defaultdict
from typing import Optional from typing import Optional
...@@ -9,18 +8,19 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs ...@@ -9,18 +8,19 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm import envs
requsets_valid_token_len = {} requsets_valid_token_len = {}
def check_stop(request: Request, def check_stop(request: Request,
max_model_len: int, max_model_len: int,
pooler_output: Optional[torch.Tensor] = None, pooler_output: Optional[torch.Tensor] = None,
use_valid_token_len:bool = False) -> bool: use_valid_token_len: bool = False,
last_token_offset: Optional[int] = 0) -> bool:
if use_valid_token_len: if use_valid_token_len:
if request.request_id not in requsets_valid_token_len: if request.request_id not in requsets_valid_token_len:
requsets_valid_token_len[request.request_id] = 0 requsets_valid_token_len[request.request_id] = 0
return False return False
valid_output_len = requsets_valid_token_len[request.request_id] valid_output_len = requsets_valid_token_len[request.request_id] - last_token_offset
else: else:
valid_output_len = request.num_output_tokens valid_output_len = request.num_output_tokens
valid_num_tokens = request.num_prompt_tokens + valid_output_len valid_num_tokens = request.num_prompt_tokens + valid_output_len
...@@ -72,6 +72,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -72,6 +72,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
continue continue
request = scheduler.requests[req_id] request = scheduler.requests[req_id]
generated_token_ids = model_runner_output.fix_sampled_token_ids[req_idx] generated_token_ids = model_runner_output.fix_sampled_token_ids[req_idx]
if not generated_token_ids:
continue
if req_id not in requsets_valid_token_len: if req_id not in requsets_valid_token_len:
requsets_valid_token_len[req_id] = 0 requsets_valid_token_len[req_id] = 0
valid_output_len = requsets_valid_token_len[req_id] valid_output_len = requsets_valid_token_len[req_id]
...@@ -81,6 +83,13 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -81,6 +83,13 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
request._all_token_ids[fix_offset] = generated_token_ids request._all_token_ids[fix_offset] = generated_token_ids
requsets_valid_token_len[req_id] += 1 requsets_valid_token_len[req_id] += 1
generated_token_ids = [generated_token_ids] generated_token_ids = [generated_token_ids]
elif envs.VLLM_ZERO_OVERHEAD_ENHANCE:
requsets_valid_token_len[req_id] += len(generated_token_ids)
request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
if num_tokens_scheduled == 0:
request.num_computed_tokens = request.num_tokens - 1
else: else:
valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
if valid_output_end == 0: if valid_output_end == 0:
...@@ -89,9 +98,9 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -89,9 +98,9 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
else: else:
request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids
request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids
requsets_valid_token_len[req_id] += len(generated_token_ids) requsets_valid_token_len[req_id] += len(generated_token_ids)
stopped = False stopped = False
new_logprobs = None new_logprobs = None
new_token_ids = generated_token_ids new_token_ids = generated_token_ids
...@@ -100,7 +109,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -100,7 +109,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# Check for stop and update request state. # Check for stop and update request state.
# This must be called before we make the EngineCoreOutput. # This must be called before we make the EngineCoreOutput.
for num_new, output_token_id in enumerate(new_token_ids, 1): for num_new, output_token_id in enumerate(new_token_ids, 1):
stopped = check_stop(request, scheduler.max_model_len, True) stopped = check_stop(request, scheduler.max_model_len, use_valid_token_len=True,
last_token_offset=len(new_token_ids) - num_new)
if stopped: if stopped:
kv_transfer_params = scheduler._free_request(request) kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed. del new_token_ids[num_new:] # Trim new tokens if needed.
...@@ -110,7 +120,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -110,7 +120,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if pooler_outputs: if pooler_outputs:
pooler_output = pooler_outputs[req_idx] pooler_output = pooler_outputs[req_idx]
stopped = check_stop(request, scheduler.max_model_len, stopped = check_stop(request, scheduler.max_model_len,
pooler_output, True) pooler_output=pooler_output, use_valid_token_len=True)
if stopped: if stopped:
kv_transfer_params = scheduler._free_request(request) kv_transfer_params = scheduler._free_request(request)
...@@ -192,6 +202,9 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -192,6 +202,9 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
req_index = model_runner_output.req_id_to_index[req_id] req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[ generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else [] req_index] if sampled_token_ids else []
if (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
request.num_computed_tokens == request.num_prompt_tokens):
generated_token_ids = generated_token_ids[:1]
scheduled_spec_token_ids = ( scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id)) scheduler_output.scheduled_spec_decode_tokens.get(req_id))
...@@ -202,13 +215,20 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -202,13 +215,20 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# num_computed_tokens is decreased by the number of rejected # num_computed_tokens is decreased by the number of rejected
# tokens, where is given by: # tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_new = len(generated_token_ids)
len(generated_token_ids)) if (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
model_runner_output.fix_req_ids and
req_id in model_runner_output.fix_req_ids and
request.num_computed_tokens > request.num_prompt_tokens + num_new):
req_idx = model_runner_output.fix_req_ids.index(req_id)
num_new = len(model_runner_output.fix_sampled_token_ids[req_idx])
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_new)
request.num_computed_tokens -= num_tokens_rejected request.num_computed_tokens -= num_tokens_rejected
spec_decoding_stats = scheduler.make_spec_decoding_stats( spec_decoding_stats = scheduler.make_spec_decoding_stats(
spec_decoding_stats, spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids), num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=len(generated_token_ids) - 1) num_accepted_tokens=num_new - 1)
# NOTE(woosuk): This has to be executed after updating # NOTE(woosuk): This has to be executed after updating
# `request.num_computed_tokens`. # `request.num_computed_tokens`.
...@@ -231,7 +251,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -231,7 +251,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if model_runner_output.is_output_valid: if model_runner_output.is_output_valid:
stopped = check_stop(request, scheduler.max_model_len, stopped = check_stop(request, scheduler.max_model_len,
False) use_valid_token_len=False)
if stopped: if stopped:
kv_transfer_params = scheduler._free_request(request) kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed. del new_token_ids[num_new:] # Trim new tokens if needed.
...@@ -242,8 +262,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -242,8 +262,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
if model_runner_output.is_output_valid: if model_runner_output.is_output_valid:
pooler_output = pooler_outputs[req_index] pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, scheduler.max_model_len, stopped = check_stop(request, scheduler.max_model_len,
pooler_output, pooler_output,
False) use_valid_token_len=False)
if stopped: if stopped:
kv_transfer_params = scheduler._free_request(request) kv_transfer_params = scheduler._free_request(request)
...@@ -350,10 +370,10 @@ def engine_core_step(core) -> tuple[dict[int, EngineCoreOutputs], bool]: ...@@ -350,10 +370,10 @@ def engine_core_step(core) -> tuple[dict[int, EngineCoreOutputs], bool]:
model_output = core.execute_model(scheduler_output) model_output = core.execute_model(scheduler_output)
if isinstance(model_output, ZeroV1ModelRunnerOutput): if isinstance(model_output, ZeroV1ModelRunnerOutput):
engine_core_outputs = zero_overhead_update_from_output(core.scheduler, engine_core_outputs = zero_overhead_update_from_output(core.scheduler,
scheduler_output, model_output) # type: ignore scheduler_output, model_output) # type: ignore
else: else:
engine_core_outputs = core.scheduler.update_from_output( engine_core_outputs = core.scheduler.update_from_output(
scheduler_output, model_output) # type: ignore scheduler_output, model_output) # type: ignore
return (engine_core_outputs, return (engine_core_outputs,
scheduler_output.total_num_scheduled_tokens > 0) scheduler_output.total_num_scheduled_tokens > 0)
\ No newline at end of file
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
import numpy as np import numpy as np
...@@ -23,7 +22,7 @@ from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput ...@@ -23,7 +22,7 @@ from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from vllm.v1.spec_decode.utils import DraftProbs from vllm.v1.spec_decode.utils import DraftProbs
from vllm.zero_overhead.utils import fused_update_input_ids_impl
class V1ZeroModelRunner(GPUModelRunner): class V1ZeroModelRunner(GPUModelRunner):
def __init__(self, vllm_config, device): def __init__(self, vllm_config, device):
...@@ -39,6 +38,9 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -39,6 +38,9 @@ class V1ZeroModelRunner(GPUModelRunner):
self.last_draft_event = torch.cuda.Event(enable_timing=False) self.last_draft_event = torch.cuda.Event(enable_timing=False)
self.spec_sampler_event = torch.cuda.Event(enable_timing=False) self.spec_sampler_event = torch.cuda.Event(enable_timing=False)
self.spec_scheduler_max_num_tokens = 0 self.spec_scheduler_max_num_tokens = 0
self.fix_req_ids = None
self.fix_sampled_token_ids = None
if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer): if hasattr(self, 'drafter') and isinstance(self.drafter, EagleProposer):
self.drafter = V1ZeroEagleProposer(self.vllm_config, self.device, self.drafter = V1ZeroEagleProposer(self.vllm_config, self.device,
self) self)
...@@ -81,6 +83,37 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -81,6 +83,37 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens, arange = self._get_cumsum_and_arange( cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens) num_scheduled_tokens)
if (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
self.speculative_config and self.last_sampler_host_tokens != None):
self.fix_req_ids = self.last_sampled_req_ids
self.last_sampler_event.synchronize() # 等上一轮主模型结束
num_gen_tokens = self.last_sampler_host_tokens.shape[-1]
if num_gen_tokens == 1:
self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
else:
self.fix_sampled_token_ids = self.rejection_sampler.parse_output(
self.last_sampler_host_tokens,
self.input_batch.vocab_size,
)
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
if start_idx == -1:
continue
num_accepted_tokens = len(self.fix_sampled_token_ids[req_idx])
req_id = self.fix_req_ids[req_idx]
if req_id in self.input_batch.req_ids:
new_req_idx = self.input_batch.req_ids.index(req_id)
new_end_idx = start_idx + num_accepted_tokens
# # 更新token统计数据
self.input_batch.num_tokens_no_spec[new_req_idx] = new_end_idx
self.input_batch.num_tokens[new_req_idx] = new_end_idx
self.input_batch.token_ids_cpu[new_req_idx, start_idx:new_end_idx] = (
self.fix_sampled_token_ids)[req_idx]
self.input_batch.num_computed_tokens_cpu[new_req_idx] -= (end_idx - new_end_idx)
if req_id in self.requests:
req_state = self.requests[req_id]
req_state.output_token_ids.extend(self.fix_sampled_token_ids[req_idx])
# Get positions. # Get positions.
positions_np = self.positions_np[:total_num_scheduled_tokens] positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices], np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
...@@ -272,7 +305,7 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -272,7 +305,7 @@ class V1ZeroModelRunner(GPUModelRunner):
input_ids_indices = [] input_ids_indices = []
token_idx = 0 token_idx = 0
if self.last_sampled_token_ids is not None: if self.last_sampled_token_ids is not None:
sampled_tokens_num = self.last_sampled_token_ids.shape[1] sampled_tokens_num = 1 if self.speculative_config else self.last_sampled_token_ids.shape[1]
for req_id in req_ids: for req_id in req_ids:
if req_id in self.last_sampled_req_ids: if req_id in self.last_sampled_req_ids:
req_idx = self.last_sampled_req_ids.index(req_id) * sampled_tokens_num req_idx = self.last_sampled_req_ids.index(req_id) * sampled_tokens_num
...@@ -286,9 +319,14 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -286,9 +319,14 @@ class V1ZeroModelRunner(GPUModelRunner):
input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32, input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32,
self.device, self.device,
True) True)
last_sampled_token_ids = self.last_sampled_token_ids.flatten() if envs.VLLM_ZERO_OVERHEAD_ENHANCE and self.speculative_config:
for i in range(sampled_tokens_num): fused_update_input_ids_impl(self.last_sampled_token_ids,input_ids,
input_ids[input_ids_indices_tensor + i] = last_sampled_token_ids[update_req_indices_tensor + i] update_req_indices_tensor,input_ids_indices_tensor)
else:
last_sampled_token_ids = self.last_sampled_token_ids.flatten()
for i in range(sampled_tokens_num):
input_ids[input_ids_indices_tensor + i] = (
last_sampled_token_ids)[update_req_indices_tensor + i]
def propose_draft_token_ids( def propose_draft_token_ids(
self, self,
...@@ -660,21 +698,50 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -660,21 +698,50 @@ class V1ZeroModelRunner(GPUModelRunner):
scheduler_output, scheduler_output,
) )
fix_req_ids = None
fix_sampled_token_ids = None
fix_draft_token_ids = None fix_draft_token_ids = None
fix_draft_req_ids = self.last_sampled_req_ids fix_draft_req_ids = self.last_sampled_req_ids
is_output_valid = False is_output_valid = False
# Get the valid generated tokens. # Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1] over_head_enhance = (envs.VLLM_ZERO_OVERHEAD_ENHANCE and
self.speculative_config is not None)
if over_head_enhance:
# if not self.speculative_config:
# self.fix_req_ids = self.last_sampled_req_ids
# if self.last_sampler_host_tokens is not None:
# self.last_sampler_event.synchronize()
# self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
# for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
# if start_idx == -1:
# self.fix_sampled_token_ids[req_idx].clear()
# continue
# req_id = self.fix_req_ids[req_idx]
# if req_id in self.input_batch.req_ids:
# new_req_idx = self.input_batch.req_ids.index(req_id)
# self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx]
# for req_idx, req_id in enumerate(self.fix_req_ids):
# if req_id in self.requests:
# req_state = self.requests[req_id]
# token_idx = self.last_sampled_token_lens[req_idx]
# if token_idx == -1:
# continue
# fix_len = len(self.fix_sampled_token_ids[req_idx])
# req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
self.last_sampler_host_tokens = None
self.last_sampled_token_ids = None
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
self.last_sampler_event.record()
self.last_sampled_token_ids = sampled_token_ids
valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
if not self.speculative_config: if not self.speculative_config:
# Speculative decoding is not enabled. # Speculative decoding is not enabled.
spec_token_ids = None spec_token_ids = None
fix_draft_req_ids = None fix_draft_req_ids = None
else: else:
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True) if not over_head_enhance:
self.spec_sampler_event.record() sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
self.spec_sampler_event.record()
if self.last_draft_host_tokens is not None: if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize() self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist() fix_draft_token_ids = self.last_draft_host_tokens.tolist()
...@@ -695,45 +762,47 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -695,45 +762,47 @@ class V1ZeroModelRunner(GPUModelRunner):
attn_metadata, attn_metadata,
) )
if self.speculative_config: if not over_head_enhance:
self.spec_sampler_event.synchronize() if self.speculative_config:
if max_gen_len == 1: self.spec_sampler_event.synchronize()
valid_sampled_token_ids = sampled_token_ids_cpu.tolist() max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = sampled_token_ids_cpu.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids_cpu,
self.input_batch.vocab_size,
)
self.last_sampler_host_tokens = None
self.last_sampled_token_ids = None
is_output_valid = True
else: else:
# Includes spec decode tokens. # No spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output( self.fix_req_ids = self.last_sampled_req_ids
sampled_token_ids_cpu, if self.last_sampler_host_tokens != None:
self.input_batch.vocab_size, self.last_sampler_event.synchronize()
) self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
self.last_sampler_host_tokens = None for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
self.last_sampled_token_ids = None if start_idx == -1:
is_output_valid = True continue
else: req_id = self.fix_req_ids[req_idx]
# No spec decode tokens. if req_id in self.input_batch.req_ids:
fix_req_ids = self.last_sampled_req_ids new_req_idx = self.input_batch.req_ids.index(req_id)
if self.last_sampler_host_tokens != None: self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx]
self.last_sampler_event.synchronize() for req_idx, req_id in enumerate(self.fix_req_ids):
fix_sampled_token_ids = self.last_sampler_host_tokens.tolist() if req_id in self.requests:
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record: req_state = self.requests[req_id]
if start_idx == -1: token_idx = self.last_sampled_token_lens[req_idx]
continue if token_idx == -1:
req_id = fix_req_ids[req_idx] continue
if req_id in self.input_batch.req_ids: fix_len = len(self.fix_sampled_token_ids[req_idx])
new_req_idx = self.input_batch.req_ids.index(req_id) req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = fix_sampled_token_ids[req_idx] self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
for req_idx, req_id in enumerate(fix_req_ids): self.last_sampler_event.record()
if req_id in self.requests: self.last_sampled_token_ids = sampled_token_ids
req_state = self.requests[req_id] valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
token_idx = self.last_sampled_token_lens[req_idx]
if token_idx == -1:
continue
fix_len = len(fix_sampled_token_ids[req_idx])
req_state.output_token_ids[token_idx:token_idx + fix_len] = fix_sampled_token_ids[req_idx]
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
self.last_sampler_event.record()
self.last_sampled_token_ids = sampled_token_ids
valid_sampled_token_ids = np.ones(sampled_token_ids.shape, dtype=int).tolist()
# Mask out the sampled tokens that should not be sampled. # Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices: for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear() valid_sampled_token_ids[i].clear()
...@@ -767,12 +836,11 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -767,12 +836,11 @@ class V1ZeroModelRunner(GPUModelRunner):
self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx]) self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx])
self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx
if req_id in self.requests: if not over_head_enhance and req_id in self.requests:
req_state = self.requests[req_id] req_state = self.requests[req_id]
cache_output_len = len(req_state.output_token_ids) cache_output_len = len(req_state.output_token_ids)
req_state.output_token_ids.extend(sampled_ids) req_state.output_token_ids.extend(sampled_ids)
self.last_sampled_token_lens.append(cache_output_len) self.last_sampled_token_lens.append(cache_output_len)
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group(): if has_kv_transfer_group():
...@@ -791,10 +859,10 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -791,10 +859,10 @@ class V1ZeroModelRunner(GPUModelRunner):
finished_sending=finished_sending, finished_sending=finished_sending,
finished_recving=finished_recving, finished_recving=finished_recving,
num_nans_in_logits=num_nans_in_logits, num_nans_in_logits=num_nans_in_logits,
fix_req_ids = fix_req_ids, fix_req_ids=self.fix_req_ids,
fix_sampled_token_ids = fix_sampled_token_ids, fix_sampled_token_ids=self.fix_sampled_token_ids,
fix_draft_tokens_ids = fix_draft_token_ids, fix_draft_tokens_ids=fix_draft_token_ids,
fix_draft_req_ids = fix_draft_req_ids, fix_draft_req_ids=fix_draft_req_ids,
is_output_valid=is_output_valid is_output_valid=is_output_valid
) )
return model_output return model_output
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment