Unverified commit 876a16f4, authored by Nick Hill and committed by GitHub
Browse files

[ModelRunner V2] Fix spec decoding + logprobs (#33391)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent aaa901ad
......@@ -335,6 +335,7 @@ def _validate_logprobs(
ref_prompt_logprob_toks,
ref_prompt_logprob_vals,
ref_prompt_token_ranks,
_,
) = ref_prompt_logprobs
for idx, (prompt_token, pos_logprob_dict) in enumerate(
zip(prompt_token_ids[1:], prompt_logprobs[1:])
......
......@@ -130,7 +130,7 @@ class LogprobsProcessor:
assert self.num_prompt_logprobs is not None
assert self.prompt_logprobs is not None
token_ids, logprobs, ranks = prompt_logprobs_tensors
token_ids, logprobs, ranks, _ = prompt_logprobs_tensors
# Recover shapes.
num_prompt_tokens, num_logprobs = logprobs.shape
......
......@@ -51,13 +51,17 @@ class LogprobsTensors(NamedTuple):
logprobs: torch.Tensor
# [num_reqs x num_generated_tokens]
selected_token_ranks: torch.Tensor
# [num_reqs]
cu_num_generated_tokens: list[int] | None = None
def tolists(self, cu_num_generated_tokens: list[int] | None = None):
return LogprobsLists(
self.logprob_token_ids.cpu().numpy(),
self.logprobs.cpu().numpy(),
self.selected_token_ranks.cpu().numpy(),
cu_num_generated_tokens,
cu_num_generated_tokens
if cu_num_generated_tokens is not None
else self.cu_num_generated_tokens,
)
def to_cpu_nonblocking(self) -> "LogprobsTensors":
......@@ -67,10 +71,14 @@ class LogprobsTensors(NamedTuple):
self.logprob_token_ids.to("cpu", non_blocking=True),
self.logprobs.to("cpu", non_blocking=True),
self.selected_token_ranks.to("cpu", non_blocking=True),
self.cu_num_generated_tokens,
)
def filter(self, mask: torch.Tensor) -> "LogprobsTensors":
"""Filter the logprobs tensors with the given bool mask."""
assert self.cu_num_generated_tokens is None, (
"filter can't be used with cu_num_generated_tokens"
)
return LogprobsTensors(
self.logprob_token_ids[mask],
self.logprobs[mask],
......
......@@ -316,7 +316,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
self.sampler(logits, idx_mapping, idx_mapping_np, pos)
self.sampler(logits, idx_mapping, idx_mapping_np, idx_mapping_np, pos)
@torch.inference_mode()
def profile_run(self) -> None:
......@@ -686,6 +686,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits,
input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np,
input_batch.cu_num_logits_np,
sample_pos,
)
......
......@@ -103,6 +103,7 @@ def compute_topk_logprobs(
logits: torch.Tensor,
num_logprobs: int,
sampled_token_ids: torch.Tensor,
cu_num_logits: list[int] | None = None,
) -> LogprobsTensors:
assert num_logprobs >= 0
batch_size, vocab_size = logits.shape
......@@ -135,4 +136,5 @@ def compute_topk_logprobs(
logprob_token_ids=logprob_token_ids,
logprobs=logprobs,
selected_token_ranks=token_ranks,
cu_num_generated_tokens=cu_num_logits,
)
......@@ -62,6 +62,7 @@ class Sampler:
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
cu_num_logits_np: np.ndarray,
pos: torch.Tensor,
) -> SamplerOutput:
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
......@@ -78,7 +79,11 @@ class Sampler:
if self.logprobs_mode == "processed_logprobs"
else logits
)
logprobs_tensors = compute_topk_logprobs(logits, max_num_logprobs, sampled)
expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
logprobs_tensors = compute_topk_logprobs(
logits, max_num_logprobs, sampled, cu_num_logits
)
else:
logprobs_tensors = None
......
......@@ -4449,7 +4449,7 @@ class GPUModelRunner(
# Compute prompt logprobs.
logprobs = self.sampler.compute_logprobs(logits)
token_ids, logprobs, ranks = self.sampler.gather_logprobs(
token_ids, logprobs, ranks, _ = self.sampler.gather_logprobs(
logprobs, num_prompt_logprobs, tgt_token_ids
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment