Unverified commit 963dc0b8, authored by Woosuk Kwon, committed by GitHub
Browse files

[Model Runner V2] Minor optimization for eagle input processing (#32535)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 8cc26acd
...@@ -827,20 +827,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -827,20 +827,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_rejected: torch.Tensor, num_rejected: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.speculator is not None assert self.speculator is not None
last_sampled_tokens = self.req_states.last_sampled_tokens[
input_batch.idx_mapping
]
next_prefill_tokens = self.req_states.next_prefill_tokens[
input_batch.idx_mapping
]
draft_tokens = self.speculator.propose( draft_tokens = self.speculator.propose(
input_batch, input_batch,
last_hidden_states, last_hidden_states,
aux_hidden_states, aux_hidden_states,
num_sampled, num_sampled,
num_rejected, num_rejected,
last_sampled_tokens, self.req_states.last_sampled_tokens,
next_prefill_tokens, self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu, self.sampler.sampling_states.seeds.gpu,
) )
......
...@@ -195,9 +195,9 @@ class EagleSpeculator: ...@@ -195,9 +195,9 @@ class EagleSpeculator:
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
# [num_reqs] # [num_reqs]
num_rejected: torch.Tensor, num_rejected: torch.Tensor,
# [num_reqs] # [max_num_reqs]
last_sampled: torch.Tensor, last_sampled: torch.Tensor,
# [num_reqs] # [max_num_reqs]
next_prefill_tokens: torch.Tensor, next_prefill_tokens: torch.Tensor,
# [max_num_reqs] # [max_num_reqs]
temperature: torch.Tensor, temperature: torch.Tensor,
...@@ -320,6 +320,7 @@ def _prepare_eagle_inputs_kernel( ...@@ -320,6 +320,7 @@ def _prepare_eagle_inputs_kernel(
eagle_positions_ptr, eagle_positions_ptr,
target_input_ids_ptr, target_input_ids_ptr,
target_positions_ptr, target_positions_ptr,
idx_mapping_ptr,
last_sampled_ptr, last_sampled_ptr,
next_prefill_tokens_ptr, next_prefill_tokens_ptr,
num_sampled_ptr, num_sampled_ptr,
...@@ -328,6 +329,8 @@ def _prepare_eagle_inputs_kernel( ...@@ -328,6 +329,8 @@ def _prepare_eagle_inputs_kernel(
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
batch_idx = tl.program_id(0) batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
query_start = tl.load(query_start_loc_ptr + batch_idx) query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1) query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start query_len = query_end - query_start
...@@ -338,11 +341,11 @@ def _prepare_eagle_inputs_kernel( ...@@ -338,11 +341,11 @@ def _prepare_eagle_inputs_kernel(
num_sampled = tl.load(num_sampled_ptr + batch_idx) num_sampled = tl.load(num_sampled_ptr + batch_idx)
if num_sampled > 0: if num_sampled > 0:
next_token = tl.load(last_sampled_ptr + batch_idx).to(tl.int32) next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
else: else:
# Chunked prefilling. # Chunked prefilling.
# Get the next prefill token. # Get the next prefill token.
next_token = tl.load(next_prefill_tokens_ptr + batch_idx) next_token = tl.load(next_prefill_tokens_ptr + req_state_idx)
# Shift target_input_ids by one. # Shift target_input_ids by one.
for i in range(1, query_len, BLOCK_SIZE): for i in range(1, query_len, BLOCK_SIZE):
...@@ -370,9 +373,9 @@ def prepare_eagle_inputs( ...@@ -370,9 +373,9 @@ def prepare_eagle_inputs(
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
# [num_reqs] # [num_reqs]
num_rejected: torch.Tensor, num_rejected: torch.Tensor,
# [num_reqs] # [max_num_reqs]
last_sampled: torch.Tensor, last_sampled: torch.Tensor,
# [num_reqs] # [max_num_reqs]
next_prefill_tokens: torch.Tensor, next_prefill_tokens: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = input_batch.num_reqs num_reqs = input_batch.num_reqs
...@@ -387,6 +390,7 @@ def prepare_eagle_inputs( ...@@ -387,6 +390,7 @@ def prepare_eagle_inputs(
input_buffers.positions, input_buffers.positions,
input_batch.input_ids, input_batch.input_ids,
input_batch.positions, input_batch.positions,
input_batch.idx_mapping,
last_sampled, last_sampled,
next_prefill_tokens, next_prefill_tokens,
num_sampled, num_sampled,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment