Unverified commit 5629f26d, authored by Lily Liu and committed by GitHub
Browse files

[V1][Spec Decode] Change Spec Decode Rejection Sampling API (#13729)

parent 9ba28043
...@@ -29,7 +29,6 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata: ...@@ -29,7 +29,6 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
temperature=torch.tensor([]), temperature=torch.tensor([]),
all_greedy=True, all_greedy=True,
all_random=False, all_random=False,
spec_token_ids=spec_tokens,
top_p=None, top_p=None,
top_k=None, top_k=None,
min_p=torch.empty(batch_size, ), min_p=torch.empty(batch_size, ),
...@@ -55,7 +54,7 @@ def test_perfect_match(sampler): ...@@ -55,7 +54,7 @@ def test_perfect_match(sampler):
metadata = create_sampling_metadata(spec_tokens) metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata) output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2, 3, 4]], expected = torch.tensor([[1, 2, 3, 4]],
dtype=torch.int, dtype=torch.int,
device=logits.device) device=logits.device)
...@@ -70,7 +69,7 @@ def test_early_mismatch(sampler): ...@@ -70,7 +69,7 @@ def test_early_mismatch(sampler):
metadata = create_sampling_metadata(spec_tokens) metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata) output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
dtype=torch.int, dtype=torch.int,
device=logits.device) device=logits.device)
...@@ -85,7 +84,7 @@ def test_multiple_sequences(sampler): ...@@ -85,7 +84,7 @@ def test_multiple_sequences(sampler):
metadata = create_sampling_metadata(spec_tokens) metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata) output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]], expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]],
dtype=torch.int, dtype=torch.int,
device=logits.device) device=logits.device)
...@@ -100,7 +99,7 @@ def test_single_token_sequence(sampler): ...@@ -100,7 +99,7 @@ def test_single_token_sequence(sampler):
metadata = create_sampling_metadata(spec_tokens) metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata) output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected) assert torch.equal(output.sampled_token_ids, expected)
...@@ -113,7 +112,7 @@ def test_empty_sequence(sampler): ...@@ -113,7 +112,7 @@ def test_empty_sequence(sampler):
metadata = create_sampling_metadata(spec_tokens) metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata) output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected) assert torch.equal(output.sampled_token_ids, expected)
...@@ -126,7 +125,7 @@ def test_multiple_mismatches(sampler): ...@@ -126,7 +125,7 @@ def test_multiple_mismatches(sampler):
metadata = create_sampling_metadata(spec_tokens) metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata) output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID], expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID],
[4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], [4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
dtype=torch.int, dtype=torch.int,
...@@ -147,7 +146,7 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): ...@@ -147,7 +146,7 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
metadata = create_sampling_metadata(spec_tokens) metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata) output = sampler(spec_tokens, logits, metadata)
expected_tensor = torch.tensor(expected, expected_tensor = torch.tensor(expected,
dtype=torch.int, dtype=torch.int,
device=logits.device) device=logits.device)
...@@ -163,7 +162,7 @@ def test_logits_shape_handling(sampler): ...@@ -163,7 +162,7 @@ def test_logits_shape_handling(sampler):
metadata = create_sampling_metadata(spec_tokens) metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens, vocab_size) logits = create_logits_tensor(output_tokens, vocab_size)
output = sampler(logits, metadata) output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device) expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected) assert torch.equal(output.sampled_token_ids, expected)
assert logits.shape[-1] == vocab_size assert logits.shape[-1] == vocab_size
...@@ -105,7 +105,6 @@ def _create_default_sampling_metadata( ...@@ -105,7 +105,6 @@ def _create_default_sampling_metadata(
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
vocab_size, device), vocab_size, device),
output_token_ids=output_token_ids, output_token_ids=output_token_ids,
spec_token_ids=None,
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device), frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
......
...@@ -123,7 +123,6 @@ def _construct_expected_sampling_metadata( ...@@ -123,7 +123,6 @@ def _construct_expected_sampling_metadata(
dtype=torch.float, dtype=torch.float,
device=device), device=device),
output_token_ids=output_token_ids, output_token_ids=output_token_ids,
spec_token_ids=None,
min_tokens=min_tokens, min_tokens=min_tokens,
no_penalties=(all(x == 0 for x in presence_penalties) no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties) and all(x == 0 for x in frequency_penalties)
......
...@@ -13,9 +13,6 @@ class SamplingMetadata: ...@@ -13,9 +13,6 @@ class SamplingMetadata:
all_greedy: bool all_greedy: bool
all_random: bool all_random: bool
# None when there are no speculated tokens.
spec_token_ids: Optional[List[List[int]]]
top_p: Optional[torch.Tensor] top_p: Optional[torch.Tensor]
top_k: Optional[torch.Tensor] top_k: Optional[torch.Tensor]
min_p: Optional[torch.Tensor] min_p: Optional[torch.Tensor]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
...@@ -52,62 +54,62 @@ class RejectionSampler(nn.Module): ...@@ -52,62 +54,62 @@ class RejectionSampler(nn.Module):
else: else:
self.forward_method = self.forward_native self.forward_method = self.forward_native
def forward(self, logits: torch.Tensor, def forward(self, draft_token_ids: List[List[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput: sampling_metadata: SamplingMetadata) -> SamplerOutput:
if not sampling_metadata.all_greedy: if not sampling_metadata.all_greedy:
raise NotImplementedError( raise NotImplementedError(
"Currently, only greedy sampling is supported by " "Currently, only greedy sampling is supported by "
"rejection sampler.") "rejection sampler.")
return self.forward_method(logits, sampling_metadata) return self.forward_method(draft_token_ids, target_probs,
sampling_metadata)
def flashinfer_sample( def flashinfer_sample(
self, self,
logits: torch.Tensor, draft_token_ids: List[List[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> SamplerOutput:
# NOTE: The following input preparation can be moved # NOTE: The following input preparation can be moved
# to the model runner in a persistent manner for better # to the model runner in a persistent manner for better
# performance. # performance.
assert sampling_metadata.spec_token_ids is not None sample_lens = [len(x) + 1 for x in draft_token_ids]
spec_token_ids = sampling_metadata.spec_token_ids # Convert draft token IDs to a tensor, split by sample_lens, then pad.
max_spec_len = max(len(s) for s in spec_token_ids) draft_token_ids = [
batch_size = len(spec_token_ids) torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
draft_token_ids = torch.full((batch_size, max_spec_len), ]
INVALID_TOKEN_ID, draft_token_ids_tensor = pad_sequence(draft_token_ids,
device="cpu", batch_first=True,
dtype=torch.long) padding_value=INVALID_TOKEN_ID)
target_token_ids = torch.full((batch_size, max_spec_len + 1), if sampling_metadata.all_greedy:
fill_value=INVALID_TOKEN_ID, target_token_ids = target_probs.argmax(dim=-1).view(-1)
device=logits.device, target_token_ids = target_token_ids.split(sample_lens)
dtype=torch.long) target_token_ids = pad_sequence(target_token_ids,
batch_first=True,
# TODO: Vectorize the following loop for better performance. padding_value=INVALID_TOKEN_ID)
start_loc = 0
for i in range(batch_size): vocab_size = target_probs.size(-1)
num_spec_tokens = len(spec_token_ids[i]) # NOTE: CPU <-> GPU synchronization happens here.
draft_token_ids[i, :num_spec_tokens] = torch.tensor( draft_token_ids_tensor = draft_token_ids_tensor.to(
spec_token_ids[i], device="cpu", dtype=torch.long) target_probs.device)
end_loc = start_loc + num_spec_tokens + 1 draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
# Assume greedy sampling. vocab_size,
target_token_ids[i, :num_spec_tokens + 1] = torch.argmax( target_probs.device)
logits[start_loc:end_loc], dim=-1) target_probs = _create_greedy_token_probs(target_token_ids,
start_loc = end_loc vocab_size,
target_probs.device)
vocab_size = logits.size(-1) uniform_samples = torch.zeros(draft_token_ids_tensor.size(0),
# NOTE: CPU <-> GPU synchronization happens here. draft_token_ids_tensor.size(1) + 1,
draft_token_ids = draft_token_ids.to(logits.device) device=target_probs.device)
draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size, else:
logits.device) raise NotImplementedError(
target_probs = _create_greedy_token_probs(target_token_ids, vocab_size, "Currently, only greedy sampling is supported by "
logits.device) "rejection sampler.")
uniform_samples = torch.zeros(batch_size,
max_spec_len + 1,
device=logits.device)
sampled_token_ids, _, _ = fs.chain_speculative_sampling( sampled_token_ids, _, _ = fs.chain_speculative_sampling(
draft_probs, draft_probs,
draft_token_ids, draft_token_ids_tensor,
uniform_samples, uniform_samples,
target_probs, target_probs,
) )
...@@ -117,35 +119,35 @@ class RejectionSampler(nn.Module): ...@@ -117,35 +119,35 @@ class RejectionSampler(nn.Module):
# TODO: The following method can be optimized for better performance. # TODO: The following method can be optimized for better performance.
def forward_native( def forward_native(
self, self,
logits: torch.Tensor, draft_token_ids: List[List[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> SamplerOutput:
assert sampling_metadata.spec_token_ids is not None sample_lens = [len(x) + 1 for x in draft_token_ids]
spec_lens = [len(x) for x in sampling_metadata.spec_token_ids] # Convert draft token IDs to a tensor, split by sample_lens, then pad.
# Add 1 to include the 'bonus' token. draft_token_ids = [
sample_lens = [x + 1 for x in spec_lens] torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
output_token_ids = logits.argmax(dim=-1).view(-1)
output_token_ids = output_token_ids.split(sample_lens)
output_token_ids = pad_sequence(output_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
# Convert spec token IDs to a tensor, split by sample_lens, then pad.
spec_token_ids = [
torch.tensor(x,
dtype=output_token_ids.dtype,
device=output_token_ids.device)
for x in sampling_metadata.spec_token_ids
] ]
spec_token_ids = pad_sequence(spec_token_ids, draft_token_ids_tensor = pad_sequence(draft_token_ids,
batch_first=True, batch_first=True,
padding_value=INVALID_TOKEN_ID) padding_value=INVALID_TOKEN_ID)
draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
# Produce a mask that remains 1 (True) until the first # Add 1 to include the 'bonus' token.
# mismatch (cumprod turns 0 after a mismatch). if sampling_metadata.all_greedy:
accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod( output_token_ids = target_probs.argmax(dim=-1).view(-1)
dim=1) output_token_ids = output_token_ids.split(sample_lens)
output_token_ids = pad_sequence(output_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
# Produce a mask that remains 1 (True) until the first
# mismatch (cumprod turns 0 after a mismatch).
accept_mask = (
output_token_ids[:, :-1] == draft_token_ids_tensor).cumprod(
dim=1)
else:
raise NotImplementedError(
"Currently, only greedy sampling is supported by "
"rejection sampler.")
# Identify valid positions (non-padding). # Identify valid positions (non-padding).
valid_mask = output_token_ids != INVALID_TOKEN_ID valid_mask = output_token_ids != INVALID_TOKEN_ID
# Generate mask with bonus token. # Generate mask with bonus token.
......
...@@ -9,7 +9,6 @@ from vllm.v1.sample.metadata import SamplingMetadata ...@@ -9,7 +9,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import (apply_all_penalties, from vllm.v1.sample.ops.penalties import (apply_all_penalties,
apply_min_token_penalties) apply_min_token_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.rejection_sampler import RejectionSampler
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
...@@ -19,22 +18,12 @@ class Sampler(nn.Module): ...@@ -19,22 +18,12 @@ class Sampler(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.topk_topp_sampler = TopKTopPSampler() self.topk_topp_sampler = TopKTopPSampler()
self.rejection_sampler = RejectionSampler()
def forward( def forward(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> SamplerOutput:
if sampling_metadata.spec_token_ids:
if sampling_metadata.max_num_logprobs:
raise NotImplementedError(
"Rejection sampling does not support logprobs.")
return self.rejection_sampler(
logits,
sampling_metadata,
)
# NOTE(woosuk): Use the original logits (before any penalties or # NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs. # temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that # This is different from the V0 sampler, which uses the logits that
...@@ -127,6 +116,14 @@ class Sampler(nn.Module): ...@@ -127,6 +116,14 @@ class Sampler(nn.Module):
) )
return sampled return sampled
# Convert target-model logits into the tensor that is handed to the rejection
# sampler as `target_probs`.
#
# - When `sampling_metadata.all_greedy` is set, the input logits are returned
#   unchanged: no temperature scaling and no softmax are applied.
# - Otherwise the logits are temperature-scaled and then softmaxed over the
#   last dimension, with the result computed in float32.
#
# NOTE(review): on the greedy path the return value is raw logits, not
# normalized probabilities. The rejection-sampler hunks visible on this page
# only take `argmax(dim=-1)` of it, which selects the same token either way;
# confirm before adding a consumer that needs a true distribution.
def compute_probs(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
# All-greedy fast path: hand the logits back as-is.
if sampling_metadata.all_greedy:
return logits
# Apply temperature. This is an in-place op changing logits.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Normalize over the last dimension; dtype=torch.float32 makes the softmax
# run and return in float32 regardless of the logits' own dtype.
return logits.softmax(dim=-1, dtype=torch.float32)
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return logits.log_softmax(dim=-1, dtype=torch.float32) return logits.log_softmax(dim=-1, dtype=torch.float32)
......
...@@ -490,23 +490,12 @@ class InputBatch: ...@@ -490,23 +490,12 @@ class InputBatch:
presence_penalties=self.presence_penalties[:num_reqs], presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs], repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(List[List[int]], self.req_output_token_ids), output_token_ids=cast(List[List[int]], self.req_output_token_ids),
spec_token_ids=None,
min_tokens=self.min_tokens, min_tokens=self.min_tokens,
no_penalties=self.no_penalties, no_penalties=self.no_penalties,
logit_bias=self.logit_bias[:num_reqs], logit_bias=self.logit_bias[:num_reqs],
allowed_token_ids_mask=allowed_token_ids_mask, allowed_token_ids_mask=allowed_token_ids_mask,
) )
def get_sampling_metadata(
self,
req_id_to_spec_token_ids: Dict[str, List[int]],
) -> SamplingMetadata:
# Set the new spec token ids in the cached sampling metadata.
self.sampling_metadata.spec_token_ids = [
req_id_to_spec_token_ids.get(req_id, []) for req_id in self.req_ids
] if req_id_to_spec_token_ids else None
return self.sampling_metadata
def _make_prompt_token_ids_tensor(self) -> torch.Tensor: def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty( prompt_token_ids_cpu_tensor = torch.empty(
......
...@@ -32,7 +32,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, ...@@ -32,7 +32,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec) KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...@@ -122,7 +122,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -122,7 +122,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_spec_decode = False self.use_spec_decode = False
if self.speculative_config: if self.speculative_config:
self.use_spec_decode = True self.use_spec_decode = True
self.rejection_sampler = RejectionSampler()
# TODO: find a better way to check if we are using ngram. # TODO: find a better way to check if we are using ngram.
assert self.speculative_config.ngram_prompt_lookup_min, \ assert self.speculative_config.ngram_prompt_lookup_min, \
"Currently, only ngram spec decode is supported in V1." "Currently, only ngram spec decode is supported in V1."
...@@ -951,12 +951,24 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -951,12 +951,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
# Sample the next token and get logprobs if needed. # Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.get_sampling_metadata( sampling_metadata = self.input_batch.sampling_metadata
scheduler_output.scheduled_spec_decode_tokens) if not self.use_spec_decode:
sampler_output = self.model.sample( sampler_output = self.model.sample(
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
else:
target_probs = self.model.sampler.compute_probs(
logits, sampling_metadata)
scheduled_request_ids = scheduler_output.num_scheduled_tokens.keys(
)
draft_token_ids = [
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
for req_id in scheduled_request_ids
]
sampler_output = self.rejection_sampler(draft_token_ids,
target_probs,
sampling_metadata)
# TODO(woosuk): The following loop can be slow since it iterates over # TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize. # the requests one by one. Optimize.
...@@ -1293,7 +1305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1293,7 +1305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
temperature=dummy_tensors(0.5), temperature=dummy_tensors(0.5),
all_greedy=False, all_greedy=False,
all_random=False, all_random=False,
spec_token_ids=None,
top_p=dummy_tensors(0.9), top_p=dummy_tensors(0.9),
top_k=dummy_tensors(logits.size(1) - 1), top_k=dummy_tensors(logits.size(1) - 1),
min_p=None, min_p=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment