Commit 96197e48 authored by jujl1
Browse files

fix: support chunk-prefill and fix bug in check_stop

parent 89639c96
......@@ -14,12 +14,13 @@ requsets_valid_token_len = {}
def check_stop(request: Request,
max_model_len: int,
pooler_output: Optional[torch.Tensor] = None,
use_valid_token_len:bool = False) -> bool:
use_valid_token_len: bool = False,
last_token_offset: Optional[int] = 0) -> bool:
if use_valid_token_len:
if request.request_id not in requsets_valid_token_len:
requsets_valid_token_len[request.request_id] = 0
return False
valid_output_len = requsets_valid_token_len[request.request_id]
valid_output_len = requsets_valid_token_len[request.request_id] - last_token_offset
else:
valid_output_len = request.num_output_tokens
valid_num_tokens = request.num_prompt_tokens + valid_output_len
......@@ -100,7 +101,8 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
for num_new, output_token_id in enumerate(new_token_ids, 1):
stopped = check_stop(request, scheduler.max_model_len, True)
stopped = check_stop(request, scheduler.max_model_len, use_valid_token_len=True,
last_token_offset=len(new_token_ids) - num_new)
if stopped:
kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
......
......@@ -38,29 +38,22 @@ def fused_last_valid_scatter_kernel(
BLOCK_T: tl.constexpr,
):
pid = tl.program_id(0)
# indices
req_idx = tl.load(update_req_ptr + pid)
input_pos = tl.load(input_pos_ptr + pid)
# load row
offs = tl.arange(0, BLOCK_T)
mask = offs < T
row_ptr = last_ids_ptr + req_idx * stride0 + offs * stride1
vals = tl.load(row_ptr, mask=mask, other=-1)
# Correct approach: index reduction (take the largest offset whose value is not -1)
idx = tl.where(vals != -1, offs, -1)
last_idx = tl.max(idx, axis=0)
# load last token
last_val = tl.load(
last_ids_ptr + req_idx * stride0 + last_idx * stride1,
mask=last_idx >= 0,
other=0,
)
# scatter
tl.store(input_ids_ptr + input_pos, last_val)
......@@ -138,9 +131,6 @@ class V1ZeroModelRunner(GPUModelRunner):
)
for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
if start_idx == -1:
self.fix_sampled_token_ids[req_idx].clear()
else:
num_accepted_tokens = len(self.fix_sampled_token_ids[req_idx])
req_id = self.fix_req_ids[req_idx]
if req_id in self.input_batch.req_ids:
......@@ -779,6 +769,31 @@ class V1ZeroModelRunner(GPUModelRunner):
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
fix_draft_req_ids = None
else:
if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist()
mask = (sampled_token_ids == -1)
mask_int = mask.int()
first_neg_one_indices = torch.argmax(mask_int, dim=1)
num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
num_accepted_tokens_tensor,
sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
......@@ -789,12 +804,9 @@ class V1ZeroModelRunner(GPUModelRunner):
self.last_sampled_token_lens = []
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
req_id = self.input_batch.req_ids[req_idx]
self.last_sampled_req_ids.append(req_id)
cache_output_len = -1
if not sampled_ids:
self.last_sampled_token_lens.append(-1)
self.token_ids_cpu_fix_record.append([req_idx, -1, -1])
continue
self.last_sampled_req_ids.append(req_id)
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids)
......@@ -809,32 +821,6 @@ class V1ZeroModelRunner(GPUModelRunner):
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
fix_draft_req_ids = None
else:
sampled_token_ids_cpu = sampled_token_ids.to('cpu', non_blocking=True)
if self.last_draft_host_tokens is not None:
self.last_draft_event.synchronize()
fix_draft_token_ids = self.last_draft_host_tokens.tolist()
mask = (sampled_token_ids == -1)
mask_int = mask.int()
first_neg_one_indices = torch.argmax(mask_int, dim=1)
num_accepted_tokens_tensor = torch.where(torch.any(mask, dim=1), first_neg_one_indices, sampled_token_ids.size(1)) - 1
spec_token_ids = self.propose_draft_token_ids(
scheduler_output,
num_accepted_tokens_tensor,
sampled_token_ids,
sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
attn_metadata,
)
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment