Commit 96197e48 authored by jujl1
Browse files

fix: support chunk-prefill and fix bug in check_stop

parent 89639c96
...@@ -14,12 +14,13 @@ requsets_valid_token_len = {} ...@@ -14,12 +14,13 @@ requsets_valid_token_len = {}
def check_stop(request: Request, def check_stop(request: Request,
max_model_len: int, max_model_len: int,
pooler_output: Optional[torch.Tensor] = None, pooler_output: Optional[torch.Tensor] = None,
use_valid_token_len:bool = False) -> bool: use_valid_token_len: bool = False,
last_token_offset: Optional[int] = 0) -> bool:
if use_valid_token_len: if use_valid_token_len:
if request.request_id not in requsets_valid_token_len: if request.request_id not in requsets_valid_token_len:
requsets_valid_token_len[request.request_id] = 0 requsets_valid_token_len[request.request_id] = 0
return False return False
valid_output_len = requsets_valid_token_len[request.request_id] valid_output_len = requsets_valid_token_len[request.request_id] - last_token_offset
else: else:
valid_output_len = request.num_output_tokens valid_output_len = request.num_output_tokens
valid_num_tokens = request.num_prompt_tokens + valid_output_len valid_num_tokens = request.num_prompt_tokens + valid_output_len
...@@ -100,7 +101,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -100,7 +101,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# Check for stop and update request state. # Check for stop and update request state.
# This must be called before we make the EngineCoreOutput. # This must be called before we make the EngineCoreOutput.
for num_new, output_token_id in enumerate(new_token_ids, 1): for num_new, output_token_id in enumerate(new_token_ids, 1):
stopped = check_stop(request, scheduler.max_model_len, True) stopped = check_stop(request, scheduler.max_model_len, use_valid_token_len=True,
last_token_offset=len(new_token_ids) - num_new)
if stopped: if stopped:
kv_transfer_params = scheduler._free_request(request) kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed. del new_token_ids[num_new:] # Trim new tokens if needed.
......
...@@ -38,29 +38,22 @@ def fused_last_valid_scatter_kernel( ...@@ -38,29 +38,22 @@ def fused_last_valid_scatter_kernel(
BLOCK_T: tl.constexpr, BLOCK_T: tl.constexpr,
): ):
pid = tl.program_id(0) pid = tl.program_id(0)
# indices # indices
req_idx = tl.load(update_req_ptr + pid) req_idx = tl.load(update_req_ptr + pid)
input_pos = tl.load(input_pos_ptr + pid) input_pos = tl.load(input_pos_ptr + pid)
# load row # load row
offs = tl.arange(0, BLOCK_T) offs = tl.arange(0, BLOCK_T)
mask = offs < T mask = offs < T
row_ptr = last_ids_ptr + req_idx * stride0 + offs * stride1 row_ptr = last_ids_ptr + req_idx * stride0 + offs * stride1
vals = tl.load(row_ptr, mask=mask, other=-1) vals = tl.load(row_ptr, mask=mask, other=-1)
# ✅ 正确做法:index reduction
idx = tl.where(vals != -1, offs, -1) idx = tl.where(vals != -1, offs, -1)
last_idx = tl.max(idx, axis=0) last_idx = tl.max(idx, axis=0)
# load last token # load last token
last_val = tl.load( last_val = tl.load(
last_ids_ptr + req_idx * stride0 + last_idx * stride1, last_ids_ptr + req_idx * stride0 + last_idx * stride1,
mask=last_idx >= 0, mask=last_idx >= 0,
other=0, other=0,
) )
# scatter # scatter
tl.store(input_ids_ptr + input_pos, last_val) tl.store(input_ids_ptr + input_pos, last_val)
...@@ -138,23 +131,20 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -138,23 +131,20 @@ class V1ZeroModelRunner(GPUModelRunner):
) )
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record: for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
if start_idx == -1: num_accepted_tokens = len(self.fix_sampled_token_ids[req_idx])
self.fix_sampled_token_ids[req_idx].clear() req_id = self.fix_req_ids[req_idx]
else: if req_id in self.input_batch.req_ids:
num_accepted_tokens = len(self.fix_sampled_token_ids[req_idx]) new_req_idx = self.input_batch.req_ids.index(req_id)
req_id = self.fix_req_ids[req_idx] new_end_idx = start_idx + num_accepted_tokens
if req_id in self.input_batch.req_ids: # # 更新token统计数据
new_req_idx = self.input_batch.req_ids.index(req_id) self.input_batch.num_tokens_no_spec[new_req_idx] = new_end_idx
new_end_idx = start_idx + num_accepted_tokens self.input_batch.num_tokens[new_req_idx] = new_end_idx
# # 更新token统计数据 self.input_batch.token_ids_cpu[new_req_idx, start_idx:new_end_idx] = self.fix_sampled_token_ids[
self.input_batch.num_tokens_no_spec[new_req_idx] = new_end_idx req_idx]
self.input_batch.num_tokens[new_req_idx] = new_end_idx self.input_batch.num_computed_tokens_cpu[new_req_idx] -= (end_idx - new_end_idx)
self.input_batch.token_ids_cpu[new_req_idx, start_idx:new_end_idx] = self.fix_sampled_token_ids[ if req_id in self.requests:
req_idx] req_state = self.requests[req_id]
self.input_batch.num_computed_tokens_cpu[new_req_idx] -= (end_idx - new_end_idx) req_state.output_token_ids.extend(self.fix_sampled_token_ids[req_idx])
if req_id in self.requests:
req_state = self.requests[req_id]
req_state.output_token_ids.extend(self.fix_sampled_token_ids[req_idx])
# Get positions. # Get positions.
positions_np = self.positions_np[:total_num_scheduled_tokens] positions_np = self.positions_np[:total_num_scheduled_tokens]
...@@ -779,6 +769,31 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -779,6 +769,31 @@ class V1ZeroModelRunner(GPUModelRunner):
for i in discard_sampled_tokens_req_indices: for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear() valid_sampled_token_ids[i].clear()
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
fix_draft_req_ids = None
else:
if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist()
mask = (sampled_token_ids == -1)
mask_int = mask.int()
first_neg_one_indices = torch.argmax(mask_int, dim=1)
num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
num_accepted_tokens_tensor,
sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
# Cache the sampled tokens in the model runner, so that the scheduler # Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back. # doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends # NOTE(woosuk): As an exception, when using PP, the scheduler sends
...@@ -789,12 +804,9 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -789,12 +804,9 @@ class V1ZeroModelRunner(GPUModelRunner):
self.last_sampled_token_lens = [] self.last_sampled_token_lens = []
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
req_id = self.input_batch.req_ids[req_idx] req_id = self.input_batch.req_ids[req_idx]
self.last_sampled_req_ids.append(req_id)
cache_output_len = -1
if not sampled_ids: if not sampled_ids:
self.last_sampled_token_lens.append(-1)
self.token_ids_cpu_fix_record.append([req_idx, -1, -1])
continue continue
self.last_sampled_req_ids.append(req_id)
start_idx = self.input_batch.num_tokens_no_spec[req_idx] start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids) end_idx = start_idx + len(sampled_ids)
...@@ -809,32 +821,6 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -809,32 +821,6 @@ class V1ZeroModelRunner(GPUModelRunner):
self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
fix_draft_req_ids = None
else:
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist()
mask = (sampled_token_ids == -1)
mask_int = mask.int()
first_neg_one_indices = torch.argmax(mask_int, dim=1)
num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
num_accepted_tokens_tensor,
sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
# Clear KVConnector state after all KVs are generated. # Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata() get_kv_transfer_group().clear_connector_metadata()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment