Unverified commit 99abb8b6, authored by Woosuk Kwon and committed by GitHub
Browse files

[V1][Spec Decode] Optimize Rejection Sampler with Triton Kernels (#14930)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 3a1e6481
...@@ -6,20 +6,23 @@ import torch ...@@ -6,20 +6,23 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
RejectionSampler)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
DEVICE = "cpu" DEVICE = "cuda"
@pytest.fixture @pytest.fixture
def sampler(): def rejection_sampler():
return RejectionSampler() return RejectionSampler()
def create_logits_tensor(token_ids: list[list[int]], def create_logits_tensor(output_token_ids: list[list[int]],
vocab_size: int = 100) -> torch.Tensor: vocab_size: int = 100) -> torch.Tensor:
"""Helper function to create logits tensor that """Helper function to create logits tensor that
will produce desired token ids on argmax""" will produce desired token ids on argmax"""
token_ids = [tokens[:-1] for tokens in output_token_ids]
num_total_tokens = sum(len(tokens) for tokens in token_ids) num_total_tokens = sum(len(tokens) for tokens in token_ids)
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE) logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
start_loc = 0 start_loc = 0
...@@ -32,14 +35,21 @@ def create_logits_tensor(token_ids: list[list[int]], ...@@ -32,14 +35,21 @@ def create_logits_tensor(token_ids: list[list[int]],
def create_sampling_metadata( def create_sampling_metadata(
all_greedy: bool, all_greedy: bool,
generators: Optional[dict[int, Any]] = None) -> SamplingMetadata: temperature: Optional[torch.Tensor] = None,
generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata:
"""Create a v1 sampling metadata object with all_greedy set """Create a v1 sampling metadata object with all_greedy set
to the given value. Either all greedy or all random sampling to the given value. Either all greedy or all random sampling
is used. is used.
""" """
generators = generators or {} generators = generators or {}
if all_greedy:
temperature = None
else:
assert temperature is not None
return SamplingMetadata( return SamplingMetadata(
temperature=torch.tensor([]), temperature=temperature,
all_greedy=all_greedy, all_greedy=all_greedy,
all_random=not all_greedy, all_random=not all_greedy,
top_p=None, top_p=None,
...@@ -61,7 +71,7 @@ def create_sampling_metadata( ...@@ -61,7 +71,7 @@ def create_sampling_metadata(
########################### Tests for Greedy Sampling ################### ########################### Tests for Greedy Sampling ###################
def test_perfect_match(sampler): def test_perfect_match(rejection_sampler):
"""Test when output tokens perfectly match speculated tokens""" """Test when output tokens perfectly match speculated tokens"""
spec_tokens = [[1, 2, 3]] spec_tokens = [[1, 2, 3]]
output_tokens = [[1, 2, 3, 4]] # 4 is the bonus token output_tokens = [[1, 2, 3, 4]] # 4 is the bonus token
...@@ -70,15 +80,23 @@ def test_perfect_match(sampler): ...@@ -70,15 +80,23 @@ def test_perfect_match(sampler):
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device) device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2, 3, 4]], expected = torch.tensor([[1, 2, 3, 4]],
dtype=torch.int, dtype=torch.int,
device=logits.device) device=logits.device)
assert torch.equal(output, expected) assert torch.equal(output, expected)
def test_early_mismatch(sampler): def test_early_mismatch(rejection_sampler):
"""Test when there's an early mismatch in tokens""" """Test when there's an early mismatch in tokens"""
spec_tokens = [[1, 2, 3]] spec_tokens = [[1, 2, 3]]
output_tokens = [[1, 5, 3, 4]] # Mismatch at position 1 output_tokens = [[1, 5, 3, 4]] # Mismatch at position 1
...@@ -87,15 +105,25 @@ def test_early_mismatch(sampler): ...@@ -87,15 +105,25 @@ def test_early_mismatch(sampler):
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device) device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) output = rejection_sampler(
expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
dtype=torch.int, dtype=torch.int,
device=logits.device) device=logits.device,
)
assert torch.equal(output, expected) assert torch.equal(output, expected)
def test_multiple_sequences(sampler): def test_multiple_sequences(rejection_sampler):
"""Test handling multiple sequences of speculated tokens""" """Test handling multiple sequences of speculated tokens"""
spec_tokens = [[1, 2], [3]] spec_tokens = [[1, 2], [3]]
output_tokens = [[1, 2, 5], [3, output_tokens = [[1, 2, 5], [3,
...@@ -105,15 +133,23 @@ def test_multiple_sequences(sampler): ...@@ -105,15 +133,23 @@ def test_multiple_sequences(sampler):
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor( bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) output = rejection_sampler(
expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]], spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
dtype=torch.int, dtype=torch.int,
device=logits.device) device=logits.device)
assert torch.equal(output, expected) assert torch.equal(output, expected)
def test_single_token_sequence(sampler): def test_single_token_sequence(rejection_sampler):
"""Test handling sequences with single token""" """Test handling sequences with single token"""
spec_tokens = [[1]] spec_tokens = [[1]]
output_tokens = [[1, 2]] # Single token with bonus token 2 output_tokens = [[1, 2]] # Single token with bonus token 2
...@@ -122,13 +158,21 @@ def test_single_token_sequence(sampler): ...@@ -122,13 +158,21 @@ def test_single_token_sequence(sampler):
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device) device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected) assert torch.equal(output, expected)
def test_empty_sequence(sampler): def test_empty_sequence(rejection_sampler):
"""Test handling empty sequence of speculated tokens""" """Test handling empty sequence of speculated tokens"""
spec_tokens: list[list[int]] = [[]] spec_tokens: list[list[int]] = [[]]
output_tokens = [[5]] # Just the bonus token output_tokens = [[5]] # Just the bonus token
...@@ -137,13 +181,21 @@ def test_empty_sequence(sampler): ...@@ -137,13 +181,21 @@ def test_empty_sequence(sampler):
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device) device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected) assert torch.equal(output, expected)
def test_multiple_mismatches(sampler): def test_multiple_mismatches(rejection_sampler):
"""Test handling multiple sequences with mismatches""" """Test handling multiple sequences with mismatches"""
spec_tokens = [[1, 2, 3], [4, 5, 6]] spec_tokens = [[1, 2, 3], [4, 5, 6]]
output_tokens = [[1, 2, 7, 6], [4, 8, 6, output_tokens = [[1, 2, 7, 6], [4, 8, 6,
...@@ -153,12 +205,22 @@ def test_multiple_mismatches(sampler): ...@@ -153,12 +205,22 @@ def test_multiple_mismatches(sampler):
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor( bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) output = rejection_sampler(
expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID], spec_decode_metadata,
[4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[1, 2, 7, PLACEHOLDER_TOKEN_ID],
[4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
dtype=torch.int, dtype=torch.int,
device=logits.device) device=logits.device,
)
assert torch.equal(output, expected) assert torch.equal(output, expected)
...@@ -166,18 +228,27 @@ def test_multiple_mismatches(sampler): ...@@ -166,18 +228,27 @@ def test_multiple_mismatches(sampler):
"spec_tokens,output_tokens,expected", "spec_tokens,output_tokens,expected",
[ [
([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus
([[1]], [[2, 3]], [[2, INVALID_TOKEN_ID]]), # First mismatch ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]), # First mismatch
([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]], ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
[[1, 5, INVALID_TOKEN_ID], [3, 4, 7]]), # Mixed matches [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]), # Mixed matches
]) ])
def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
expected):
"""Parametrized test for various matching scenarios""" """Parametrized test for various matching scenarios"""
metadata = create_sampling_metadata(all_greedy=True) metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens], bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
device=logits.device) device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected_tensor = torch.tensor(expected, expected_tensor = torch.tensor(expected,
dtype=torch.int, dtype=torch.int,
device=logits.device) device=logits.device)
...@@ -190,21 +261,31 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): ...@@ -190,21 +261,31 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
@pytest.mark.parametrize("batch_size", [1, 4, 8]) @pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5]) @pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
@pytest.mark.parametrize("n_rep", [20]) @pytest.mark.parametrize("n_rep", [20])
def test_deterministic_when_seeded(sampler, k: int, vocab_size: int, def test_deterministic_when_seeded(
batch_size: int, frac_seeded: float, rejection_sampler,
n_rep: int): k: int,
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) vocab_size: int,
target_probs = torch.rand(batch_size * (k + 1), batch_size: int,
frac_seeded: float,
n_rep: int,
):
num_tokens = batch_size * k
draft_probs = torch.rand(num_tokens,
vocab_size, vocab_size,
dtype=torch.float32) dtype=torch.float32,
device=DEVICE)
draft_probs = F.softmax(draft_probs, dim=-1)
target_logits = torch.rand_like(draft_probs)
bonus_token_ids = torch.randint(low=0, bonus_token_ids = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, 1), size=(batch_size, 1),
dtype=torch.int64) dtype=torch.int64,
device=DEVICE)
draft_token_ids = torch.randint(low=0, draft_token_ids = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, k), size=(batch_size, k),
dtype=torch.int64) dtype=torch.int64,
device=DEVICE)
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
...@@ -215,10 +296,21 @@ def test_deterministic_when_seeded(sampler, k: int, vocab_size: int, ...@@ -215,10 +296,21 @@ def test_deterministic_when_seeded(sampler, k: int, vocab_size: int,
for i in range(batch_size) if seeded_mask[i] for i in range(batch_size) if seeded_mask[i]
} }
temperature = torch.ones(batch_size,
dtype=torch.float32,
device=DEVICE)
sampling_metadata = create_sampling_metadata(all_greedy=False, sampling_metadata = create_sampling_metadata(all_greedy=False,
temperature=temperature,
generators=seeded_seqs) generators=seeded_seqs)
rep_result = sampler(draft_token_ids.tolist(), draft_probs, spec_decode_metadata = SpecDecodeMetadata.make_dummy(
bonus_token_ids, target_probs, sampling_metadata) draft_token_ids.tolist(), device=DEVICE)
rep_result = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
sampling_metadata=sampling_metadata,
)
results.append(rep_result) results.append(rep_result)
...@@ -257,10 +349,10 @@ def test_rejection_sampling_approximates_target_distribution(): ...@@ -257,10 +349,10 @@ def test_rejection_sampling_approximates_target_distribution():
num_reference_probs = 100 num_reference_probs = 100
# Prepare draft, target, and reference probability distributions # Prepare draft, target, and reference probability distributions
draft_probs, target_probs = (F.softmax( draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32),
torch.rand(vocab_size, dtype=torch.float32), dim=-1)
dim=-1, target_logits = torch.rand(vocab_size, dtype=torch.float32)
) for _ in range(2)) target_probs = F.softmax(target_logits, dim=-1)
reference_probs = F.softmax( reference_probs = F.softmax(
torch.rand(num_reference_probs, vocab_size, dtype=torch.float32), torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
dim=-1, dim=-1,
...@@ -273,7 +365,7 @@ def test_rejection_sampling_approximates_target_distribution(): ...@@ -273,7 +365,7 @@ def test_rejection_sampling_approximates_target_distribution():
for num_samples in sample_sizes: for num_samples in sample_sizes:
# Sample using rejection sampling. # Sample using rejection sampling.
rej_sample_probs = estimate_rejection_sampling_pdf( rej_sample_probs = estimate_rejection_sampling_pdf(
draft_probs, target_probs, k, vocab_size, num_samples) draft_probs, target_logits, k, vocab_size, num_samples)
rej_sample_probs = rej_sample_probs.to(DEVICE) rej_sample_probs = rej_sample_probs.to(DEVICE)
# Average distance from reference probs. # Average distance from reference probs.
...@@ -313,7 +405,7 @@ def get_ratio_first_to_last(elements: list[float]) -> float: ...@@ -313,7 +405,7 @@ def get_ratio_first_to_last(elements: list[float]) -> float:
def estimate_rejection_sampling_pdf( def estimate_rejection_sampling_pdf(
draft_probs: torch.Tensor, draft_probs: torch.Tensor,
target_probs: torch.Tensor, target_logits: torch.Tensor,
k: int, k: int,
vocab_size: int, vocab_size: int,
num_samples: int, num_samples: int,
...@@ -323,35 +415,44 @@ def estimate_rejection_sampling_pdf( ...@@ -323,35 +415,44 @@ def estimate_rejection_sampling_pdf(
Args: Args:
draft_probs: Draft probability distribution. draft_probs: Draft probability distribution.
target_probs: Target probability distribution. target_logits: Target logits.
num_samples: Number of samples to draw. num_samples: Number of samples to draw.
Returns: Returns:
Estimated probability distribution of the output tokens. Estimated probability distribution of the output tokens.
""" """
sampler = RejectionSampler() rejection_sampler = RejectionSampler()
# Repeat draft probs num_samples times. num_tokens = num_samples * k
# Repeat draft probs num_samples * k times.
draft_probs = draft_probs.reshape(1, 1, draft_probs = draft_probs.reshape(1, 1,
vocab_size).repeat(num_samples, k, 1) vocab_size).repeat(num_samples, k, 1)
# Repeat target probs num_samples * (k + 1) times. # Repeat target probs num_tokens times.
target_probs = target_probs.reshape(1, 1, vocab_size).repeat( target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)
num_samples, k + 1, 1).reshape(num_samples * (k + 1), vocab_size)
# Randomly sample draft token ids from draft probs. # Randomly sample draft token ids from draft probs.
draft_token_ids = torch.multinomial(draft_probs[:, 0, :], draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
num_samples=k, num_samples=k,
replacement=True).reshape( replacement=True).reshape(
num_samples, k) num_samples, k)
draft_probs = draft_probs.view(num_tokens, vocab_size)
# Bonus tokens not used but required. # Bonus tokens not used but required.
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64,
device=DEVICE).repeat(num_samples, 1) device=DEVICE).repeat(num_samples, 1)
sampling_metadata = create_sampling_metadata(all_greedy=False) temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
output_token_ids = sampler(draft_token_ids.tolist(), draft_probs, sampling_metadata = create_sampling_metadata(all_greedy=False,
bonus_token_ids, target_probs, temperature=temperature)
sampling_metadata) spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids.tolist(), device=bonus_token_ids.device)
output_token_ids = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
sampling_metadata=sampling_metadata,
)
output_token_ids = output_token_ids[:, :-1].flatten() output_token_ids = output_token_ids[:, :-1].flatten()
hist = torch.histogram(output_token_ids.to(dtype=torch.float, hist = torch.histogram(output_token_ids.to(dtype=torch.float,
......
...@@ -35,7 +35,6 @@ if TYPE_CHECKING: ...@@ -35,7 +35,6 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
......
...@@ -46,7 +46,7 @@ class SamplerOutput: ...@@ -46,7 +46,7 @@ class SamplerOutput:
# [num_reqs, max_num_generated_tokens] # [num_reqs, max_num_generated_tokens]
# Different requests can have different number of generated tokens. # Different requests can have different number of generated tokens.
# All requests are padded to max_num_generated_tokens. # All requests are padded to max_num_generated_tokens.
# INVALID_TOKEN_ID (-1 by default) is used for padding. # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
sampled_token_ids: torch.Tensor sampled_token_ids: torch.Tensor
logprobs_tensors: Optional[LogprobsTensors] logprobs_tensors: Optional[LogprobsTensors]
......
# SPDX-License-Identifier: Apache-2.0
from typing import Union
import torch
def compiled_softmax(
    logits: torch.Tensor,
    temperature: Union[float, torch.Tensor] = 1.0,
) -> torch.Tensor:
    """Faster softmax kernel generated by torch.compile.

    Args:
        logits: [n, vocab_size]
        temperature: A float applied to every row, or a per-row tensor of
            shape [n] or [n, 1].

    Returns:
        The result of the compiled `_softmax` helper: a softmax over the
        last dimension of `logits / temperature`.
    """
    # NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic.
    torch._dynamo.mark_dynamic(logits, index=0)
    if isinstance(temperature, torch.Tensor):
        if temperature.dim() == 1:
            # Broadcasting aligns trailing dimensions, so a [n] tensor would
            # be matched against vocab_size instead of the row dimension.
            # Reshape to [n, 1] so each row is divided by its own temperature.
            temperature = temperature.unsqueeze(-1)
        torch._dynamo.mark_dynamic(temperature, index=0)
    return _softmax(logits, temperature)
@torch.compile
def _softmax(
    logits: torch.Tensor,
    temperature: Union[float, torch.Tensor],
) -> torch.Tensor:
    """Temperature-scale `logits` and normalize over the last dimension.

    The division relies on ordinary broadcasting, so `temperature` may be a
    Python float or any tensor broadcastable against `logits`. The softmax is
    computed and returned in float32 regardless of the input dtype.
    """
    scaled_logits = torch.div(logits, temperature)
    return scaled_logits.softmax(dim=-1, dtype=torch.float32)
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import numpy as np
import torch
@dataclass
class SpecDecodeMetadata:
    """Bundle of per-step tensors describing the draft tokens of a batch.

    Draft tokens of all requests are stored flattened; `num_draft_tokens` and
    `cu_num_draft_tokens` describe how the flat layout maps back to requests.
    """

    # [num_tokens] draft token ids of all requests, flattened.
    draft_token_ids: torch.Tensor
    # [batch_size] number of draft tokens per request.
    num_draft_tokens: list[int]
    # [batch_size] inclusive cumulative sum of `num_draft_tokens`.
    cu_num_draft_tokens: torch.Tensor
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
    bonus_logits_indices: torch.Tensor
    # [num_tokens + batch_size]
    logits_indices: torch.Tensor

    def __post_init__(self):
        # Largest per-request draft length. `default=0` keeps an empty batch
        # constructible instead of raising ValueError from max().
        self.max_spec_len = max(self.num_draft_tokens, default=0)

    @classmethod
    def make_dummy(
        cls,
        draft_token_ids: list[list[int]],
        device: torch.device,
    ) -> "SpecDecodeMetadata":
        """Build placeholder metadata from nested per-request draft tokens.

        Only the draft token ids and their counts are derived from the input;
        the three `*_indices` tensors are zero-filled tensors of the right
        length, not real indices.

        Args:
            draft_token_ids: One list of draft token ids per request. Inner
                lists may be empty.
            device: Device on which every tensor is created.

        Returns:
            A `SpecDecodeMetadata` whose tensors live on `device`.
        """
        batch_size = len(draft_token_ids)
        num_draft_tokens = [len(ids) for ids in draft_token_ids]
        # Linear-time flatten; sum(lists, []) re-copies the accumulator on
        # every step.
        flattened_draft_token_ids = [
            token_id for ids in draft_token_ids for token_id in ids
        ]
        num_tokens = len(flattened_draft_token_ids)
        draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids,
                                              dtype=torch.int32,
                                              device=device)
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
        cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(
            device)

        target_logits_indices = torch.zeros(num_tokens,
                                            dtype=torch.int32,
                                            device=device)
        bonus_logits_indices = torch.zeros(batch_size,
                                           dtype=torch.int32,
                                           device=device)
        logits_indices = torch.zeros(num_tokens + batch_size,
                                     dtype=torch.int32,
                                     device=device)
        return cls(
            draft_token_ids=draft_token_ids_tensor,
            num_draft_tokens=num_draft_tokens,
            cu_num_draft_tokens=cu_num_draft_tokens_tensor,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from vllm.v1.sample.ops.topk_topp_sampler import random_sample # noqa
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
......
...@@ -34,7 +34,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, ...@@ -34,7 +34,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput) ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.spec_decode.utils import is_spec_decode_supported
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
...@@ -149,7 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -149,7 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_spec_decode = False self.use_spec_decode = False
if self.speculative_config: if self.speculative_config:
self.use_spec_decode = True self.use_spec_decode = True
self.rejection_sampler = RejectionSampler()
# TODO: find a better way to check if we are using ngram. # TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \ assert self.speculative_config.ngram_prompt_lookup_min, \
"Currently, only ngram spec decode is supported in V1." "Currently, only ngram spec decode is supported in V1."
...@@ -162,6 +162,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -162,6 +162,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.speculative_config.ngram_prompt_lookup_min, self.speculative_config.ngram_prompt_lookup_min,
self.speculative_config.num_speculative_tokens, self.speculative_config.num_speculative_tokens,
) )
self.rejection_sampler = RejectionSampler()
# Request states. # Request states.
self.requests: dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
...@@ -452,7 +453,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -452,7 +453,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _prepare_inputs( def _prepare_inputs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> tuple[FlashAttentionMetadata, torch.Tensor]: ) -> tuple[FlashAttentionMetadata, torch.Tensor,
Optional[SpecDecodeMetadata]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0 assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
...@@ -577,22 +579,33 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -577,22 +579,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
use_spec_decode = len( use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0 scheduler_output.scheduled_spec_decode_tokens) > 0
if use_spec_decode: if not use_spec_decode:
logits_indices = self._calc_spec_decode_metadata(
scheduler_output, cu_num_tokens)
else:
# NOTE(woosuk): Due to chunked prefills, the batch may contain # NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token # partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity. # from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests. # We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs. # TODO: Support prompt logprobs.
logits_indices = attn_metadata.query_start_loc[1:] - 1 logits_indices = attn_metadata.query_start_loc[1:] - 1
spec_decode_metadata = None
else:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices
# Hot-Swap lora model # Hot-Swap lora model
if self.lora_config: if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens) self.set_active_loras(self.input_batch, num_scheduled_tokens)
return attn_metadata, logits_indices return attn_metadata, logits_indices, spec_decode_metadata
def _compute_cascade_attn_prefix_len( def _compute_cascade_attn_prefix_len(
self, self,
...@@ -732,50 +745,79 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -732,50 +745,79 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _calc_spec_decode_metadata( def _calc_spec_decode_metadata(
self, self,
scheduler_output: "SchedulerOutput", num_draft_tokens: np.ndarray,
cu_num_tokens: np.ndarray, cu_num_scheduled_tokens: np.ndarray,
) -> torch.Tensor: ) -> SpecDecodeMetadata:
# Get the number of spec decode tokens for each request. # Inputs:
num_reqs = self.input_batch.num_reqs # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32) # num_draft_tokens: [ 3, 0, 2, 0, 1]
for i, req_id in enumerate(self.input_batch.req_ids): # Outputs:
num_spec_decode_tokens[i] = len( # cu_num_draft_tokens: [ 3, 3, 5, 5, 6]
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) # logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106,
# 206, 207, 208]
# Get spec decode logits indices. # target_logits_indices: [ 0, 1, 2, 5, 6, 9]
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2] # bonus_logits_indices: [ 3, 4, 7, 8, 10]
# cu_num_tokens: [4, 104, 107, 207, 209]
# num_spec_tokens_list: [3, 0, 2, 0, 1] # Compute the logits indices.
# num_sampled_tokens: [4, 1, 3, 1, 2] # [4, 1, 3, 1, 2]
# spec_decode_logits_indices: num_sampled_tokens = num_draft_tokens + 1
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] # Step 1. [4, 5, 8, 9, 11]
num_sampled_tokens = num_spec_decode_tokens + 1 cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
# logits_start_loc: [0, 103, 104, 206, 207] total_num_sampled_tokens = cu_num_sampled_tokens[-1]
logits_start_loc = cu_num_tokens - num_sampled_tokens # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# [0, 103, 104, 206, 207] -> cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] num_sampled_tokens)
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens) # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# The following three lines: arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11] logits_indices = np.repeat(
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens) cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9] # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
# -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] logits_indices += arange
cumsums_sampled_offsets = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens) # Compute the bonus logits indices.
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] bonus_logits_indices = cu_num_sampled_tokens - 1
# - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] # Compute the draft logits indices.
total_num_sampled_tokens = num_sampled_tokens.sum() # [3, 3, 5, 5, 6]
sampled_arange = (self.arange_np[:total_num_sampled_tokens] - cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
cumsums_sampled_offsets) total_num_draft_tokens = cu_num_draft_tokens[-1]
# [0, 0, 0, 3, 3, 5]
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] -> cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens,
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] num_draft_tokens)
spec_decode_logits_indices = logits_start_loc + sampled_arange # [0, 1, 2, 0, 1, 0]
return torch.from_numpy(spec_decode_logits_indices).to( arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets
# [0, 0, 0, 5, 5, 9]
target_logits_indices = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
# [0, 1, 2, 5, 6, 9]
target_logits_indices += arange
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).to(self.device,
non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).to(
self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
self.device, non_blocking=True) self.device, non_blocking=True)
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
draft_token_ids = self.input_ids[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1]
metadata = SpecDecodeMetadata(
draft_token_ids=draft_token_ids,
num_draft_tokens=num_draft_tokens.tolist(),
cu_num_draft_tokens=cu_num_draft_tokens,
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
return metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"): def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs: if not scheduled_encoder_inputs:
...@@ -931,7 +973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -931,7 +973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs = [] encoder_outputs = []
# Prepare the decoder inputs. # Prepare the decoder inputs.
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) attn_metadata, logits_indices, spec_decode_metadata = (
self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
...@@ -1006,31 +1049,29 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1006,31 +1049,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Sample the next token and get logprobs if needed. # Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata sampling_metadata = self.input_batch.sampling_metadata
if not self.use_spec_decode: if spec_decode_metadata is None:
sampler_output = self.model.sample( sampler_output = self.model.sample(
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
else: else:
draft_token_ids = [ # TODO(woosuk): Optimize the memory usage.
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
for req_id in self.input_batch.req_ids
]
sample_lens = [len(tokens) + 1 for tokens in draft_token_ids]
recover_logits_idx = np.cumsum(sample_lens) - 1
target_probs = self.rejection_sampler.compute_probs(
logits, sampling_metadata, sample_lens)
sampler_output = self.model.sample( sampler_output = self.model.sample(
logits=logits[recover_logits_idx, :], logits=bonus_logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
bonus_token_ids = sampler_output.sampled_token_ids bonus_token_ids = sampler_output.sampled_token_ids
# TODO(woosuk): Optimize the memory usage.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler( output_token_ids = self.rejection_sampler(
draft_token_ids, spec_decode_metadata,
None, # draft_probs None, # draft_probs
target_logits,
bonus_token_ids, bonus_token_ids,
target_probs, sampling_metadata,
sampling_metadata) )
sampler_output.sampled_token_ids = output_token_ids sampler_output.sampled_token_ids = output_token_ids
# TODO(woosuk): The following loop can be slow since it iterates over # TODO(woosuk): The following loop can be slow since it iterates over
...@@ -1066,13 +1107,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1066,13 +1107,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
valid_sampled_token_ids = sampled_token_ids.tolist() valid_sampled_token_ids = sampled_token_ids.tolist()
else: else:
# Includes spec decode tokens. # Includes spec decode tokens.
valid_mask = sampled_token_ids != INVALID_TOKEN_ID valid_sampled_token_ids = self.rejection_sampler.parse_output(
gen_lens = valid_mask.sum(dim=1).tolist() sampled_token_ids, self.input_batch.vocab_size)
# TODO(woosuk): Optimize this.
valid_sampled_token_ids = [
seq.tolist()
for seq in sampled_token_ids[valid_mask].split(gen_lens)
]
if not self.use_spec_decode: if not self.use_spec_decode:
spec_token_ids = None spec_token_ids = None
...@@ -1316,6 +1352,33 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1316,6 +1352,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"initializing the engine.") from e "initializing the engine.") from e
else: else:
raise e raise e
if self.use_spec_decode:
draft_token_ids = [[0] for _ in range(num_reqs)]
dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids, self.device)
num_tokens = sum(len(ids) for ids in draft_token_ids)
# draft_probs = torch.randn(
# num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype)
draft_probs = None
target_logits = torch.randn(num_tokens,
logits.shape[-1],
device=self.device,
dtype=logits.dtype)
# NOTE(woosuk): Here, we should use int32 because the sampler uses
# int32 for bonus_token_ids. If the dtype mismatches, re-compilation
# will occur at runtime.
bonus_token_ids = torch.zeros(num_reqs,
device=self.device,
dtype=torch.int32)
self.rejection_sampler(
dummy_spec_decode_metadata,
draft_probs,
target_logits,
bonus_token_ids,
dummy_metadata,
)
return sampler_output return sampler_output
def profile_run(self) -> None: def profile_run(self) -> None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment