Commit bc945a5a authored by jujl1's avatar jujl1
Browse files

fix: 解决同时处理prefill和decode时的prefill请求token计数错误

parent 96197e48
...@@ -4,6 +4,8 @@ from enum import Enum ...@@ -4,6 +4,8 @@ from enum import Enum
import os import os
import torch import torch
import vllm.envs as envs import vllm.envs as envs
import triton
import triton.language as tl
zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1' zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1'
...@@ -69,3 +71,60 @@ def zero_overhead_stream(target_device): ...@@ -69,3 +71,60 @@ def zero_overhead_stream(target_device):
if target_device not in alloc_stream.keys(): if target_device not in alloc_stream.keys():
alloc_stream[target_device] = torch.cuda.Stream(device=target_device) alloc_stream[target_device] = torch.cuda.Stream(device=target_device)
return alloc_stream[target_device] return alloc_stream[target_device]
@triton.jit
def fused_last_valid_scatter_kernel(
last_ids_ptr, # [B, T]
input_ids_ptr, # [N]
update_req_ptr, # [U]
input_pos_ptr, # [U]
stride0,
stride1,
T,
BLOCK_T: tl.constexpr,
):
pid = tl.program_id(0)
# indices
req_idx = tl.load(update_req_ptr + pid)
input_pos = tl.load(input_pos_ptr + pid)
# load row
offs = tl.arange(0, BLOCK_T)
mask = offs < T
row_ptr = last_ids_ptr + req_idx * stride0 + offs * stride1
vals = tl.load(row_ptr, mask=mask, other=-1)
idx = tl.where(vals != -1, offs, -1)
last_idx = tl.max(idx, axis=0)
# load last token
last_val = tl.load(
last_ids_ptr + req_idx * stride0 + last_idx * stride1,
mask=last_idx >= 0,
other=0,
)
# scatter
tl.store(input_ids_ptr + input_pos, last_val)
def fused_update_input_ids_impl(
last_sampled_token_ids,
input_ids,
update_req_indices,
input_ids_indices,
):
B, T = last_sampled_token_ids.shape
U = update_req_indices.numel()
BLOCK_T = 1024
assert T <= BLOCK_T
grid = (U,)
fused_last_valid_scatter_kernel[grid](
last_sampled_token_ids,
input_ids,
update_req_indices,
input_ids_indices,
last_sampled_token_ids.stride(0),
last_sampled_token_ids.stride(1),
T,
BLOCK_T=BLOCK_T,
)
\ No newline at end of file
...@@ -82,16 +82,10 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -82,16 +82,10 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
requsets_valid_token_len[req_id] += 1 requsets_valid_token_len[req_id] += 1
generated_token_ids = [generated_token_ids] generated_token_ids = [generated_token_ids]
else: else:
valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
if valid_output_end == 0:
request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids
else:
request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids
request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids
requsets_valid_token_len[req_id] += len(generated_token_ids) requsets_valid_token_len[req_id] += len(generated_token_ids)
request._output_token_ids[:] = request._output_token_ids[:requsets_valid_token_len[req_id]] request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[:] = request._all_token_ids[:request.num_prompt_tokens + requsets_valid_token_len[req_id]] request._all_token_ids[fix_offset : ] = generated_token_ids
stopped = False stopped = False
new_logprobs = None new_logprobs = None
...@@ -194,7 +188,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -194,7 +188,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
req_index = model_runner_output.req_id_to_index[req_id] req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[ generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else[] req_index] if sampled_token_ids else[]
if request.num_computed_tokens == request.num_prompt_tokens:
generated_token_ids = generated_token_ids[:1]
scheduled_spec_token_ids = ( scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id)) scheduler_output.scheduled_spec_decode_tokens.get(req_id))
...@@ -207,7 +202,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -207,7 +202,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_new = len(generated_token_ids) num_new = len(generated_token_ids)
if (model_runner_output.fix_req_ids and req_id in model_runner_output.fix_req_ids if (model_runner_output.fix_req_ids and req_id in model_runner_output.fix_req_ids
and request.num_computed_tokens >= request.num_prompt_tokens + num_new): and request.num_computed_tokens > request.num_prompt_tokens + num_new):
req_idx = model_runner_output.fix_req_ids.index(req_id) req_idx = model_runner_output.fix_req_ids.index(req_id)
num_new = len(model_runner_output.fix_sampled_token_ids[req_idx]) num_new = len(model_runner_output.fix_sampled_token_ids[req_idx])
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_new) num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_new)
......
...@@ -22,40 +22,7 @@ from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput ...@@ -22,40 +22,7 @@ from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute_model
from vllm.v1.spec_decode.utils import DraftProbs from vllm.v1.spec_decode.utils import DraftProbs
from vllm.zero_overhead.utils import fused_update_input_ids_impl
import triton
import triton.language as tl
@triton.jit
def fused_last_valid_scatter_kernel(
last_ids_ptr, # [B, T]
input_ids_ptr, # [N]
update_req_ptr, # [U]
input_pos_ptr, # [U]
stride0,
stride1,
T,
BLOCK_T: tl.constexpr,
):
pid = tl.program_id(0)
# indices
req_idx = tl.load(update_req_ptr + pid)
input_pos = tl.load(input_pos_ptr + pid)
# load row
offs = tl.arange(0, BLOCK_T)
mask = offs < T
row_ptr = last_ids_ptr + req_idx * stride0 + offs * stride1
vals = tl.load(row_ptr, mask=mask, other=-1)
idx = tl.where(vals != -1, offs, -1)
last_idx = tl.max(idx, axis=0)
# load last token
last_val = tl.load(
last_ids_ptr + req_idx * stride0 + last_idx * stride1,
mask=last_idx >= 0,
other=0,
)
# scatter
tl.store(input_ids_ptr + input_pos, last_val)
class V1ZeroModelRunner(GPUModelRunner): class V1ZeroModelRunner(GPUModelRunner):
def __init__(self, vllm_config, device): def __init__(self, vllm_config, device):
...@@ -116,21 +83,21 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -116,21 +83,21 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens, arange = self._get_cumsum_and_arange( cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens) num_scheduled_tokens)
self.fix_req_ids = self.last_sampled_req_ids
if self.last_sampler_host_tokens != None:
self.last_sampler_event.synchronize() # 等上一轮主模型结束
if self.speculative_config: # 处理上一轮mtp
num_gen_tokens = self.last_sampler_host_tokens.shape[-1]
if num_gen_tokens == 1:
self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
else:
# Includes spec decode tokens.
self.fix_sampled_token_ids = self.rejection_sampler.parse_output(
self.last_sampler_host_tokens,
self.input_batch.vocab_size,
)
if self.speculative_config and self.last_sampler_host_tokens != None:
self.fix_req_ids = self.last_sampled_req_ids
self.last_sampler_event.synchronize() # 等上一轮主模型结束
num_gen_tokens = self.last_sampler_host_tokens.shape[-1]
if num_gen_tokens == 1:
self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
else:
self.fix_sampled_token_ids = self.rejection_sampler.parse_output(
self.last_sampler_host_tokens,
self.input_batch.vocab_size,
)
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record: for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
if start_idx == -1:
continue
num_accepted_tokens = len(self.fix_sampled_token_ids[req_idx]) num_accepted_tokens = len(self.fix_sampled_token_ids[req_idx])
req_id = self.fix_req_ids[req_idx] req_id = self.fix_req_ids[req_idx]
if req_id in self.input_batch.req_ids: if req_id in self.input_batch.req_ids:
...@@ -332,31 +299,6 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -332,31 +299,6 @@ class V1ZeroModelRunner(GPUModelRunner):
True) True)
last_draft_token_ids = self.last_draft_token_ids.flatten().to(torch.int) last_draft_token_ids = self.last_draft_token_ids.flatten().to(torch.int)
input_ids[input_ids_indices_tensor] = last_draft_token_ids[update_req_indices_tensor] input_ids[input_ids_indices_tensor] = last_draft_token_ids[update_req_indices_tensor]
def fused_update_input_ids(
last_sampled_token_ids,
input_ids,
update_req_indices,
input_ids_indices,
):
B, T = last_sampled_token_ids.shape
U = update_req_indices.numel()
BLOCK_T = 1024
assert T <= BLOCK_T
grid = (U,)
fused_last_valid_scatter_kernel[grid](
last_sampled_token_ids,
input_ids,
update_req_indices,
input_ids_indices,
last_sampled_token_ids.stride(0),
last_sampled_token_ids.stride(1),
T,
BLOCK_T=BLOCK_T,
)
update_req_indices = [] update_req_indices = []
input_ids_indices = [] input_ids_indices = []
token_idx = 0 token_idx = 0
...@@ -374,13 +316,12 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -374,13 +316,12 @@ class V1ZeroModelRunner(GPUModelRunner):
input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32, input_ids_indices_tensor = async_tensor_h2d(input_ids_indices, torch.int32,
self.device, self.device,
True) True)
fused_update_input_ids( if self.speculative_config:
self.last_sampled_token_ids, fused_update_input_ids_impl(self.last_sampled_token_ids,input_ids,
input_ids, update_req_indices_tensor,input_ids_indices_tensor)
update_req_indices_tensor, else:
input_ids_indices_tensor) last_sampled_token_ids = self.last_sampled_token_ids.flatten()
input_ids[input_ids_indices_tensor] =last_sampled_token_ids[update_req_indices_tensor]
def propose_draft_token_ids( def propose_draft_token_ids(
self, self,
...@@ -757,7 +698,26 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -757,7 +698,26 @@ class V1ZeroModelRunner(GPUModelRunner):
is_output_valid = False is_output_valid = False
# Get the valid generated tokens. # Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
if not self.speculative_config:
self.fix_req_ids = self.last_sampled_req_ids
if self.last_sampler_host_tokens is not None:
self.last_sampler_event.synchronize()
self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
if start_idx == -1:
continue
req_id = self.fix_req_ids[req_idx]
if req_id in self.input_batch.req_ids:
new_req_idx = self.input_batch.req_ids.index(req_id)
self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx]
for req_idx, req_id in enumerate(self.fix_req_ids):
if req_id in self.requests:
req_state = self.requests[req_id]
token_idx = self.last_sampled_token_lens[req_idx]
if token_idx == -1:
continue
fix_len = len(self.fix_sampled_token_ids[req_idx])
req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
self.last_sampler_host_tokens = None self.last_sampler_host_tokens = None
self.last_sampled_token_ids = None self.last_sampled_token_ids = None
self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True) self.last_sampler_host_tokens = sampled_token_ids.to('cpu', non_blocking=True)
...@@ -804,10 +764,12 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -804,10 +764,12 @@ class V1ZeroModelRunner(GPUModelRunner):
self.last_sampled_token_lens = [] self.last_sampled_token_lens = []
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
req_id = self.input_batch.req_ids[req_idx] req_id = self.input_batch.req_ids[req_idx]
cache_output_len = -1
self.last_sampled_req_ids.append(req_id)
if not sampled_ids: if not sampled_ids:
self.last_sampled_token_lens.append(-1)
self.token_ids_cpu_fix_record.append([req_idx, -1, -1])
continue continue
self.last_sampled_req_ids.append(req_id)
start_idx = self.input_batch.num_tokens_no_spec[req_idx] start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids) end_idx = start_idx + len(sampled_ids)
assert end_idx <= self.max_model_len, ( assert end_idx <= self.max_model_len, (
...@@ -820,6 +782,11 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -820,6 +782,11 @@ class V1ZeroModelRunner(GPUModelRunner):
self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx]) self.token_ids_cpu_fix_record.append([req_idx, start_idx, end_idx])
self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx
if not self.speculative_config and req_id in self.requests:
req_state = self.requests[req_id]
cache_output_len = len(req_state.output_token_ids)
req_state.output_token_ids.extend(sampled_ids)
self.last_sampled_token_lens.append(cache_output_len)
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group(): if has_kv_transfer_group():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment