Unverified Commit 876a16f4 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[ModelRunner V2] Fix spec decoding + logprobs (#33391)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent aaa901ad
...@@ -335,6 +335,7 @@ def _validate_logprobs( ...@@ -335,6 +335,7 @@ def _validate_logprobs(
ref_prompt_logprob_toks, ref_prompt_logprob_toks,
ref_prompt_logprob_vals, ref_prompt_logprob_vals,
ref_prompt_token_ranks, ref_prompt_token_ranks,
_,
) = ref_prompt_logprobs ) = ref_prompt_logprobs
for idx, (prompt_token, pos_logprob_dict) in enumerate( for idx, (prompt_token, pos_logprob_dict) in enumerate(
zip(prompt_token_ids[1:], prompt_logprobs[1:]) zip(prompt_token_ids[1:], prompt_logprobs[1:])
......
...@@ -130,7 +130,7 @@ class LogprobsProcessor: ...@@ -130,7 +130,7 @@ class LogprobsProcessor:
assert self.num_prompt_logprobs is not None assert self.num_prompt_logprobs is not None
assert self.prompt_logprobs is not None assert self.prompt_logprobs is not None
token_ids, logprobs, ranks = prompt_logprobs_tensors token_ids, logprobs, ranks, _ = prompt_logprobs_tensors
# Recover shapes. # Recover shapes.
num_prompt_tokens, num_logprobs = logprobs.shape num_prompt_tokens, num_logprobs = logprobs.shape
......
...@@ -51,13 +51,17 @@ class LogprobsTensors(NamedTuple): ...@@ -51,13 +51,17 @@ class LogprobsTensors(NamedTuple):
logprobs: torch.Tensor logprobs: torch.Tensor
# [num_reqs x num_generated_tokens] # [num_reqs x num_generated_tokens]
selected_token_ranks: torch.Tensor selected_token_ranks: torch.Tensor
# [num_reqs]
cu_num_generated_tokens: list[int] | None = None
def tolists(self, cu_num_generated_tokens: list[int] | None = None): def tolists(self, cu_num_generated_tokens: list[int] | None = None):
return LogprobsLists( return LogprobsLists(
self.logprob_token_ids.cpu().numpy(), self.logprob_token_ids.cpu().numpy(),
self.logprobs.cpu().numpy(), self.logprobs.cpu().numpy(),
self.selected_token_ranks.cpu().numpy(), self.selected_token_ranks.cpu().numpy(),
cu_num_generated_tokens, cu_num_generated_tokens
if cu_num_generated_tokens is not None
else self.cu_num_generated_tokens,
) )
def to_cpu_nonblocking(self) -> "LogprobsTensors": def to_cpu_nonblocking(self) -> "LogprobsTensors":
...@@ -67,10 +71,14 @@ class LogprobsTensors(NamedTuple): ...@@ -67,10 +71,14 @@ class LogprobsTensors(NamedTuple):
self.logprob_token_ids.to("cpu", non_blocking=True), self.logprob_token_ids.to("cpu", non_blocking=True),
self.logprobs.to("cpu", non_blocking=True), self.logprobs.to("cpu", non_blocking=True),
self.selected_token_ranks.to("cpu", non_blocking=True), self.selected_token_ranks.to("cpu", non_blocking=True),
self.cu_num_generated_tokens,
) )
def filter(self, mask: torch.Tensor) -> "LogprobsTensors": def filter(self, mask: torch.Tensor) -> "LogprobsTensors":
"""Filter the logprobs tensors with the given bool mask.""" """Filter the logprobs tensors with the given bool mask."""
assert self.cu_num_generated_tokens is None, (
"filter can't be used with cu_num_generated_tokens"
)
return LogprobsTensors( return LogprobsTensors(
self.logprob_token_ids[mask], self.logprob_token_ids[mask],
self.logprobs[mask], self.logprobs[mask],
......
...@@ -316,7 +316,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -316,7 +316,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): During the initial memory profiling, the sampler may skip # NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible # top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution. # during actual execution.
self.sampler(logits, idx_mapping, idx_mapping_np, pos) self.sampler(logits, idx_mapping, idx_mapping_np, idx_mapping_np, pos)
@torch.inference_mode() @torch.inference_mode()
def profile_run(self) -> None: def profile_run(self) -> None:
...@@ -686,6 +686,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -686,6 +686,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits, logits,
input_batch.expanded_idx_mapping, input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np, input_batch.idx_mapping_np,
input_batch.cu_num_logits_np,
sample_pos, sample_pos,
) )
......
...@@ -103,6 +103,7 @@ def compute_topk_logprobs( ...@@ -103,6 +103,7 @@ def compute_topk_logprobs(
logits: torch.Tensor, logits: torch.Tensor,
num_logprobs: int, num_logprobs: int,
sampled_token_ids: torch.Tensor, sampled_token_ids: torch.Tensor,
cu_num_logits: list[int] | None = None,
) -> LogprobsTensors: ) -> LogprobsTensors:
assert num_logprobs >= 0 assert num_logprobs >= 0
batch_size, vocab_size = logits.shape batch_size, vocab_size = logits.shape
...@@ -135,4 +136,5 @@ def compute_topk_logprobs( ...@@ -135,4 +136,5 @@ def compute_topk_logprobs(
logprob_token_ids=logprob_token_ids, logprob_token_ids=logprob_token_ids,
logprobs=logprobs, logprobs=logprobs,
selected_token_ranks=token_ranks, selected_token_ranks=token_ranks,
cu_num_generated_tokens=cu_num_logits,
) )
...@@ -62,6 +62,7 @@ class Sampler: ...@@ -62,6 +62,7 @@ class Sampler:
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray, idx_mapping_np: np.ndarray,
cu_num_logits_np: np.ndarray,
pos: torch.Tensor, pos: torch.Tensor,
) -> SamplerOutput: ) -> SamplerOutput:
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
...@@ -78,7 +79,11 @@ class Sampler: ...@@ -78,7 +79,11 @@ class Sampler:
if self.logprobs_mode == "processed_logprobs" if self.logprobs_mode == "processed_logprobs"
else logits else logits
) )
logprobs_tensors = compute_topk_logprobs(logits, max_num_logprobs, sampled) expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
logprobs_tensors = compute_topk_logprobs(
logits, max_num_logprobs, sampled, cu_num_logits
)
else: else:
logprobs_tensors = None logprobs_tensors = None
......
...@@ -4449,7 +4449,7 @@ class GPUModelRunner( ...@@ -4449,7 +4449,7 @@ class GPUModelRunner(
# Compute prompt logprobs. # Compute prompt logprobs.
logprobs = self.sampler.compute_logprobs(logits) logprobs = self.sampler.compute_logprobs(logits)
token_ids, logprobs, ranks = self.sampler.gather_logprobs( token_ids, logprobs, ranks, _ = self.sampler.gather_logprobs(
logprobs, num_prompt_logprobs, tgt_token_ids logprobs, num_prompt_logprobs, tgt_token_ids
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment