Unverified Commit dcee9be9 authored by Giancarlo Delfin's avatar Giancarlo Delfin Committed by GitHub
Browse files

[Model Runner V2] Fix draft logits not populated during cudagraph replay (#37639)


Signed-off-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
parent bd8c4c07
...@@ -195,7 +195,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -195,7 +195,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_speculative_steps=self.num_speculative_steps, num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
device=self.device, device=self.device,
cache_draft_logits=not use_strict_rejection_sampling,
) )
self.input_buffers = InputBuffers( self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
...@@ -446,7 +445,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -446,7 +445,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
next_prefill_tokens=self.req_states.next_prefill_tokens, next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu, temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu, seeds=self.sampler.sampling_states.seeds.gpu,
draft_logits_out=self.req_states.draft_logits,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True, dummy_run=True,
skip_attn_for_dummy_run=skip_attn, skip_attn_for_dummy_run=skip_attn,
...@@ -815,11 +813,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -815,11 +813,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
# Rejection sampling for spec decoding. # Rejection sampling for spec decoding.
assert self.rejection_sampler is not None assert self.rejection_sampler is not None
assert self.speculator is not None
sampler_output = self.rejection_sampler( sampler_output = self.rejection_sampler(
logits, logits,
input_batch, input_batch,
# Draft logits are needed for probabilistic rejection sampling. # Draft logits are needed for probabilistic rejection sampling.
self.req_states.draft_logits, self.speculator.draft_logits,
) )
# Get the number of sampled and rejected tokens. # Get the number of sampled and rejected tokens.
...@@ -1145,7 +1144,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1145,7 +1144,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.next_prefill_tokens, self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu, self.sampler.sampling_states.seeds.gpu,
self.req_states.draft_logits,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
) )
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
......
...@@ -76,6 +76,17 @@ class EagleSpeculator: ...@@ -76,6 +76,17 @@ class EagleSpeculator:
device=device, device=device,
) )
cache_draft_logits = self.speculative_config.rejection_sample_method != "strict"
self.draft_logits: torch.Tensor | None = None
if cache_draft_logits:
self.draft_logits = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
self.vocab_size,
dtype=torch.float32,
device=device,
)
# currently we don't support PIECEWISE for Eagle. # currently we don't support PIECEWISE for Eagle.
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL: if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL:
...@@ -158,7 +169,6 @@ class EagleSpeculator: ...@@ -158,7 +169,6 @@ class EagleSpeculator:
slot_mappings: dict[str, torch.Tensor] | None, slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None, num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
draft_logits_out: torch.Tensor | None = None,
) -> None: ) -> None:
pos = self.input_buffers.positions[:num_reqs] pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
...@@ -185,8 +195,8 @@ class EagleSpeculator: ...@@ -185,8 +195,8 @@ class EagleSpeculator:
self.seeds, self.seeds,
pos + 1, pos + 1,
apply_temperature=True, apply_temperature=True,
processed_logits_out=draft_logits_out[:, step] processed_logits_out=self.draft_logits[:, step]
if draft_logits_out is not None if self.draft_logits is not None
else None, else None,
) )
self.draft_tokens[:num_reqs, step] = draft_tokens self.draft_tokens[:num_reqs, step] = draft_tokens
...@@ -241,8 +251,6 @@ class EagleSpeculator: ...@@ -241,8 +251,6 @@ class EagleSpeculator:
temperature: torch.Tensor, temperature: torch.Tensor,
# [max_num_reqs] # [max_num_reqs]
seeds: torch.Tensor, seeds: torch.Tensor,
# [max_num_reqs, num_speculative_steps, vocab_size]
draft_logits_out: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor | None = None, num_tokens_across_dp: torch.Tensor | None = None,
dummy_run: bool = False, dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False, skip_attn_for_dummy_run: bool = False,
...@@ -308,8 +316,8 @@ class EagleSpeculator: ...@@ -308,8 +316,8 @@ class EagleSpeculator:
self.seeds, self.seeds,
pos + 1, pos + 1,
apply_temperature=True, apply_temperature=True,
processed_logits_out=draft_logits_out[:, 0] processed_logits_out=self.draft_logits[:, 0]
if draft_logits_out is not None if self.draft_logits is not None
else None, else None,
) )
...@@ -394,7 +402,6 @@ class EagleSpeculator: ...@@ -394,7 +402,6 @@ class EagleSpeculator:
slot_mappings_updated, slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=batch_desc.cg_mode, cudagraph_runtime_mode=batch_desc.cg_mode,
draft_logits_out=draft_logits_out,
) )
return self.draft_tokens[:num_reqs] return self.draft_tokens[:num_reqs]
......
...@@ -15,7 +15,6 @@ class RequestState: ...@@ -15,7 +15,6 @@ class RequestState:
num_speculative_steps: int, num_speculative_steps: int,
vocab_size: int, vocab_size: int,
device: torch.device, device: torch.device,
cache_draft_logits: bool,
): ):
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
...@@ -71,18 +70,6 @@ class RequestState: ...@@ -71,18 +70,6 @@ class RequestState:
dtype=torch.int64, dtype=torch.int64,
device=device, device=device,
) )
# Draft token logits.
# NOTE: This tensor maintains the "processed" logits after applying temperature,
# top-p, etc.
self.draft_logits: torch.Tensor | None = None
if cache_draft_logits:
self.draft_logits = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
self.vocab_size,
dtype=torch.float32,
device=device,
)
self.next_prefill_tokens = torch.zeros( self.next_prefill_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device self.max_num_reqs, dtype=torch.int32, device=device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment