Unverified Commit c50e105a authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Model Runner V2] Avoid prepare prefill kernel launch overhead (#34780)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent a766b303
...@@ -614,16 +614,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -614,16 +614,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
max_query_len = num_scheduled_tokens.max().item() max_query_len = num_scheduled_tokens.max().item()
# Get prefill tokens. # Get prefill tokens if any.
prepare_prefill_inputs( if self.req_states.any_prefills(idx_mapping_np):
self.input_buffers.input_ids, prepare_prefill_inputs(
self.req_states.next_prefill_tokens, self.input_buffers.input_ids,
idx_mapping, self.req_states.next_prefill_tokens,
query_start_loc, idx_mapping,
self.req_states.all_token_ids.gpu, query_start_loc,
self.req_states.prefill_len.gpu, self.req_states.all_token_ids.gpu,
self.req_states.num_computed_tokens.gpu, self.req_states.prefill_len.gpu,
) self.req_states.num_computed_tokens.gpu,
)
# Prepare positions and seq_lens. # Prepare positions and seq_lens.
prepare_pos_seq_lens( prepare_pos_seq_lens(
......
...@@ -60,10 +60,7 @@ class RequestState: ...@@ -60,10 +60,7 @@ class RequestState:
# Last sampled tokens. # Last sampled tokens.
self.last_sampled_tokens = torch.zeros( self.last_sampled_tokens = torch.zeros(
self.max_num_reqs, self.max_num_reqs, 1, dtype=torch.int64, device=device
1,
dtype=torch.int64,
device=device,
) )
# Draft tokens. # Draft tokens.
...@@ -118,3 +115,9 @@ class RequestState: ...@@ -118,3 +115,9 @@ class RequestState:
return return
self.index_to_req_id.pop(req_idx, None) self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx) self.free_indices.append(req_idx)
def any_prefills(self, idx_mapping_np: np.ndarray) -> bool:
return np.any(
self.num_computed_prefill_tokens[idx_mapping_np]
< self.prefill_len.np[idx_mapping_np]
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment