Unverified commit f205c098, authored by Jonas M. Kübler and committed by GitHub
Browse files

[Bugfix] Unify rank computation across regular decoding and speculative decoding (#7899)

parent ef99a787
...@@ -4,10 +4,12 @@ import pytest ...@@ -4,10 +4,12 @@ import pytest
import torch import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import _get_ranks
from vllm.model_executor.layers.typical_acceptance_sampler import ( from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler) TypicalAcceptanceSampler)
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
from vllm.spec_decode.util import split_batch_by_proposal_len from vllm.spec_decode.util import (get_sampled_token_logprobs,
split_batch_by_proposal_len)
def test_get_all_seq_ids(): def test_get_all_seq_ids():
...@@ -126,3 +128,20 @@ def mock_spec_decode_sampler(acceptance_sampler_method): ...@@ -126,3 +128,20 @@ def mock_spec_decode_sampler(acceptance_sampler_method):
return sampler return sampler
else: else:
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}") raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
def test_get_sampled_token_logprobs():
    """Check that speculative-decoding ranks agree with `_get_ranks` on ties.

    Every vocabulary entry carries the same log-probability, so each sampled
    token is tied with the other token in its row. The ranks returned by
    `get_sampled_token_logprobs` must equal those returned by `_get_ranks`
    for the same (flattened) inputs.
    """
    # Shape (num_steps, batch_size, vocab_size); all log-probabilities equal.
    tied_logprobs = torch.tensor([[[-.1, -.1]] * 2])
    # Shape (num_steps, batch_size).
    sampled_ids = torch.tensor([[1, 0]])

    spec_decode_ranks, _ = get_sampled_token_logprobs(
        tied_logprobs, sampled_ids)

    # Collapse the step and batch dimensions before calling `_get_ranks`.
    flat_logprobs = tied_logprobs.reshape((2, -1))
    flat_ids = sampled_ids.reshape(-1)
    regular_ranks = _get_ranks(flat_logprobs, flat_ids)

    assert torch.equal(spec_decode_ranks.reshape(-1), regular_ranks)
...@@ -43,8 +43,8 @@ def get_sampled_token_logprobs( ...@@ -43,8 +43,8 @@ def get_sampled_token_logprobs(
sampled_token_ids, ] sampled_token_ids, ]
expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand( expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
-1, -1, vocab_size) -1, -1, vocab_size)
sampled_token_ids_ranks = (logprob_tensor >= sampled_token_ids_ranks = (logprob_tensor >
expanded_selected_logprobs).sum(-1) expanded_selected_logprobs).sum(-1).add_(1)
return sampled_token_ids_ranks, selected_logprobs return sampled_token_ids_ranks, selected_logprobs
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment