Unverified Commit f32c7d6f authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Simplify Eagle bookkeeping with num_rejected (#29347)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 3cfa63ad
...@@ -344,8 +344,8 @@ def _post_update_kernel( ...@@ -344,8 +344,8 @@ def _post_update_kernel(
sampled_tokens_ptr, sampled_tokens_ptr,
sampled_tokens_stride, sampled_tokens_stride,
num_sampled_ptr, num_sampled_ptr,
num_rejected_ptr,
query_start_loc_ptr, query_start_loc_ptr,
cu_num_logits_ptr,
): ):
req_id = tl.program_id(0) req_id = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_id) req_state_idx = tl.load(idx_mapping_ptr + req_id)
...@@ -360,17 +360,10 @@ def _post_update_kernel( ...@@ -360,17 +360,10 @@ def _post_update_kernel(
query_start = tl.load(query_start_loc_ptr + req_id) query_start = tl.load(query_start_loc_ptr + req_id)
query_end = tl.load(query_start_loc_ptr + req_id + 1) query_end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = query_end - query_start query_len = query_end - query_start
num_rejected = tl.load(num_rejected_ptr + req_id)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx) num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
num_computed += query_len num_computed += query_len - num_rejected
# Consider the rejected tokens in spec decoding.
if num_sampled > 0:
# NOTE(woosuk): We must skip num_sampled == 0 to account for chunked prefills.
logits_start = tl.load(cu_num_logits_ptr + req_id)
logits_end = tl.load(cu_num_logits_ptr + req_id + 1)
num_logits = logits_end - logits_start
num_rejected = num_logits - num_sampled
num_computed -= num_rejected
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed) tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)
...@@ -385,10 +378,10 @@ def post_update( ...@@ -385,10 +378,10 @@ def post_update(
sampled_tokens: torch.Tensor, sampled_tokens: torch.Tensor,
# [num_reqs] # [num_reqs]
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [num_reqs + 1] # [num_reqs + 1]
query_start_loc: torch.Tensor, query_start_loc: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
) -> None: ) -> None:
num_reqs = idx_mapping.shape[0] num_reqs = idx_mapping.shape[0]
_post_update_kernel[(num_reqs,)]( _post_update_kernel[(num_reqs,)](
...@@ -398,7 +391,7 @@ def post_update( ...@@ -398,7 +391,7 @@ def post_update(
sampled_tokens, sampled_tokens,
sampled_tokens.stride(0), sampled_tokens.stride(0),
num_sampled, num_sampled,
num_rejected,
query_start_loc, query_start_loc,
cu_num_logits,
num_warps=1, num_warps=1,
) )
...@@ -46,7 +46,10 @@ from vllm.v1.worker.gpu.input_batch import ( ...@@ -46,7 +46,10 @@ from vllm.v1.worker.gpu.input_batch import (
) )
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
from vllm.v1.worker.gpu.spec_decode import init_speculator from vllm.v1.worker.gpu.spec_decode import init_speculator
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
get_num_rejected,
rejection_sample,
)
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
...@@ -311,12 +314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -311,12 +314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
device=self.device, device=self.device,
) )
num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device) num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device)
num_rejected = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
self.propose_draft( self.propose_draft(
input_batch=input_batch, input_batch=input_batch,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
last_hidden_states=hidden_states, last_hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states, aux_hidden_states=aux_hidden_states,
num_sampled=num_sampled, num_sampled=num_sampled,
num_rejected=num_rejected,
) )
@torch.inference_mode() @torch.inference_mode()
...@@ -606,7 +611,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -606,7 +611,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch: InputBatch, input_batch: InputBatch,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
grammar_output: GrammarOutput | None, grammar_output: GrammarOutput | None,
) -> tuple[SamplerOutput, torch.Tensor]: ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
sample_hidden_states = hidden_states[input_batch.logits_indices] sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None: if grammar_output is not None:
...@@ -632,6 +637,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -632,6 +637,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# No draft tokens (common case). # No draft tokens (common case).
# 0 if chunked-prefilling, 1 if not. # 0 if chunked-prefilling, 1 if not.
num_sampled = (~is_chunked_prefilling).int() num_sampled = (~is_chunked_prefilling).int()
num_rejected = torch.zeros_like(num_sampled)
else: else:
# Draft tokens for spec decoding. # Draft tokens for spec decoding.
input_ids = input_batch.input_ids[input_batch.logits_indices] input_ids = input_batch.input_ids[input_batch.logits_indices]
...@@ -642,9 +648,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -642,9 +648,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.num_speculative_steps, self.num_speculative_steps,
) )
num_sampled *= ~is_chunked_prefilling num_sampled *= ~is_chunked_prefilling
num_rejected = get_num_rejected(
input_batch.cu_num_logits,
num_sampled,
)
sampler_output.sampled_token_ids = sampled_tokens sampler_output.sampled_token_ids = sampled_tokens
# TODO(woosuk): Support logprobs with spec decoding. # TODO(woosuk): Support logprobs with spec decoding.
return sampler_output, num_sampled return sampler_output, num_sampled, num_rejected
def compute_prompt_logprobs( def compute_prompt_logprobs(
self, self,
...@@ -750,6 +760,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -750,6 +760,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch: InputBatch, input_batch: InputBatch,
sampled_tokens: torch.Tensor, sampled_tokens: torch.Tensor,
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
num_rejected: torch.Tensor,
) -> None: ) -> None:
# Update the number of computed tokens. # Update the number of computed tokens.
post_update( post_update(
...@@ -758,8 +769,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -758,8 +769,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.req_states.last_sampled_tokens, self.req_states.last_sampled_tokens,
sampled_tokens, sampled_tokens,
num_sampled, num_sampled,
num_rejected,
input_batch.query_start_loc, input_batch.query_start_loc,
input_batch.cu_num_logits,
) )
# Update the number of computed prefill tokens. # Update the number of computed prefill tokens.
...@@ -779,6 +790,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -779,6 +790,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
last_hidden_states: torch.Tensor, last_hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None, aux_hidden_states: list[torch.Tensor] | None,
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
num_rejected: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = input_batch.num_reqs num_reqs = input_batch.num_reqs
idx_mapping_np = input_batch.idx_mapping_np idx_mapping_np = input_batch.idx_mapping_np
...@@ -800,6 +812,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -800,6 +812,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
last_hidden_states, last_hidden_states,
aux_hidden_states, aux_hidden_states,
num_sampled, num_sampled,
num_rejected,
self.req_states.last_sampled_tokens, self.req_states.last_sampled_tokens,
next_prefill_tokens, next_prefill_tokens,
) )
...@@ -958,7 +971,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -958,7 +971,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.execute_model_state = None # type: ignore self.execute_model_state = None # type: ignore
assert sampling_metadata is not None assert sampling_metadata is not None
sampler_output, num_sampled_tokens = self.sample( sampler_output, num_sampled, num_rejected = self.sample(
hidden_states, input_batch, sampling_metadata, grammar_output hidden_states, input_batch, sampling_metadata, grammar_output
) )
prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch) prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
...@@ -979,7 +992,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -979,7 +992,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
async_output = AsyncOutput( async_output = AsyncOutput(
model_runner_output=model_runner_output, model_runner_output=model_runner_output,
sampler_output=sampler_output, sampler_output=sampler_output,
num_sampled_tokens=num_sampled_tokens, num_sampled_tokens=num_sampled,
copy_stream=self.output_copy_stream, copy_stream=self.output_copy_stream,
copy_event=self.output_copy_event, copy_event=self.output_copy_event,
) )
...@@ -990,7 +1003,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -990,7 +1003,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# This sequencing may slightly reduce latency as async D2H copy does not # This sequencing may slightly reduce latency as async D2H copy does not
# need to wait for the postprocess to finish. # need to wait for the postprocess to finish.
self.postprocess( self.postprocess(
input_batch, sampler_output.sampled_token_ids, num_sampled_tokens input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
) )
if self.do_spec_decode: if self.do_spec_decode:
_ = self.propose_draft( _ = self.propose_draft(
...@@ -998,7 +1011,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -998,7 +1011,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampling_metadata, sampling_metadata,
hidden_states, hidden_states,
None, # aux_hidden_states None, # aux_hidden_states
num_sampled_tokens, num_sampled,
num_rejected,
) )
if self.use_async_scheduling: if self.use_async_scheduling:
......
...@@ -60,6 +60,8 @@ class EagleSpeculator: ...@@ -60,6 +60,8 @@ class EagleSpeculator:
aux_hidden_states: list[torch.Tensor] | None, aux_hidden_states: list[torch.Tensor] | None,
# [num_reqs] # [num_reqs]
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [max_num_reqs, 1] # [max_num_reqs, 1]
last_sampled: torch.Tensor, last_sampled: torch.Tensor,
# [num_reqs] # [num_reqs]
...@@ -84,6 +86,7 @@ class EagleSpeculator: ...@@ -84,6 +86,7 @@ class EagleSpeculator:
self.input_ids, self.input_ids,
input_batch, input_batch,
num_sampled, num_sampled,
num_rejected,
last_sampled, last_sampled,
next_prefill_tokens, next_prefill_tokens,
) )
...@@ -139,8 +142,8 @@ def _prepare_eagle_inputs_kernel( ...@@ -139,8 +142,8 @@ def _prepare_eagle_inputs_kernel(
last_sampled_ptr, last_sampled_ptr,
next_prefill_tokens_ptr, next_prefill_tokens_ptr,
num_sampled_ptr, num_sampled_ptr,
num_rejected_ptr,
query_start_loc_ptr, query_start_loc_ptr,
cu_num_logits_ptr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
batch_idx = tl.program_id(0) batch_idx = tl.program_id(0)
...@@ -149,17 +152,13 @@ def _prepare_eagle_inputs_kernel( ...@@ -149,17 +152,13 @@ def _prepare_eagle_inputs_kernel(
query_len = query_end - query_start query_len = query_end - query_start
# Get the true query length and next token after accounting for rejected tokens. # Get the true query length and next token after accounting for rejected tokens.
num_rejected = tl.load(num_rejected_ptr + batch_idx)
query_len -= num_rejected
num_sampled = tl.load(num_sampled_ptr + batch_idx) num_sampled = tl.load(num_sampled_ptr + batch_idx)
if num_sampled > 0: if num_sampled > 0:
req_state_idx = tl.load(idx_mapping_ptr + batch_idx) req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32) next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
logits_start = tl.load(cu_num_logits_ptr + batch_idx)
logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
num_logits = logits_end - logits_start
num_rejected = num_logits - num_sampled
query_len -= num_rejected
else: else:
# Chunked prefilling. # Chunked prefilling.
# Get the next prefill token. # Get the next prefill token.
...@@ -182,6 +181,8 @@ def prepare_eagle_inputs( ...@@ -182,6 +181,8 @@ def prepare_eagle_inputs(
input_batch: InputBatch, input_batch: InputBatch,
# [num_reqs] # [num_reqs]
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
# [num_reqs]
num_rejected: torch.Tensor,
# [max_num_reqs, 1] # [max_num_reqs, 1]
last_sampled: torch.Tensor, last_sampled: torch.Tensor,
# [max_num_reqs] # [max_num_reqs]
...@@ -201,8 +202,8 @@ def prepare_eagle_inputs( ...@@ -201,8 +202,8 @@ def prepare_eagle_inputs(
last_sampled, last_sampled,
next_prefill_tokens, next_prefill_tokens,
num_sampled, num_sampled,
num_rejected,
input_batch.query_start_loc, input_batch.query_start_loc,
input_batch.cu_num_logits,
BLOCK_SIZE=1024, BLOCK_SIZE=1024,
) )
return last_token_indices return last_token_indices
...@@ -69,3 +69,15 @@ def rejection_sample( ...@@ -69,3 +69,15 @@ def rejection_sample(
num_warps=1, num_warps=1,
) )
return sampled, num_sampled return sampled, num_sampled
@torch.compile(dynamic=True)
def get_num_rejected(
cu_num_logits: torch.Tensor,
num_sampled: torch.Tensor,
) -> torch.Tensor:
num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
num_rejected = num_logits - num_sampled
# No token is rejected for chunked prefills.
num_rejected *= num_sampled > 0
return num_rejected
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment