Unverified commit 626daa20, authored by Benjamin Chislett and committed by GitHub
Browse files

[Feat] Unified Synthetic Acceptance Rate for V1 and V2 (#40662)


Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent fe85a92e
...@@ -1310,6 +1310,54 @@ def test_dflash_acceptance_rates(dflash_config): ...@@ -1310,6 +1310,54 @@ def test_dflash_acceptance_rates(dflash_config):
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@single_gpu_only
def test_synthetic_acceptance_rate():
    """Check that synthetic rejection sampling yields a measured acceptance
    length close to the configured ``synthetic_acceptance_length``."""
    num_spec_tokens = 3
    expected_acceptance_len = 1.875
    tolerance = 0.15
    context_len = 2048

    # Speculative settings: the synthetic method is driven by a target mean
    # acceptance length rather than by explicit per-position rates.
    speculative_config = {
        "method": "eagle3",
        "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
        "num_speculative_tokens": num_spec_tokens,
        "max_model_len": context_len,
        "rejection_sample_method": "synthetic",
        "synthetic_acceptance_length": expected_acceptance_len,
    }
    spec_llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        trust_remote_code=True,
        speculative_config=speculative_config,
        max_model_len=context_len,
        enforce_eager=True,
        # Stats logging stays enabled because the metrics are read back below.
        disable_log_stats=False,
    )

    prompts = get_test_prompts(mm_enabled=False, num_prompts=50)
    sampling_params = SamplingParams(temperature=0, max_tokens=64, ignore_eos=True)
    spec_llm.chat(prompts, sampling_params)

    acceptance_len = compute_acceptance_len(spec_llm.get_metrics())
    print(
        f"Synthetic acceptance length: {acceptance_len:.3f}"
        f" (expected={expected_acceptance_len:.3f},"
        f" tolerance=±{tolerance})"
    )
    deviation = abs(acceptance_len - expected_acceptance_len)
    assert deviation <= tolerance, (
        f"Synthetic acceptance length {acceptance_len:.3f} is not within"
        f" ±{tolerance} of expected {expected_acceptance_len:.3f}"
    )

    # Drop the engine reference before releasing accelerator memory.
    del spec_llm
    torch.accelerator.empty_cache()
    cleanup_dist_env_and_memory()
def test_dflash_correctness(dflash_config): def test_dflash_correctness(dflash_config):
""" """
E2E test for DFlash (block diffusion) speculative decoding. E2E test for DFlash (block diffusion) speculative decoding.
......
...@@ -933,3 +933,64 @@ def test_sample_recovered_tokens( ...@@ -933,3 +933,64 @@ def test_sample_recovered_tokens(
device=DEVICE_TYPE, device=DEVICE_TYPE,
) )
assert torch.equal(recovered_token_ids, ref_recovered_token_ids) assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
########################### Tests for Synthetic Rejection Sampling #########
def _make_synthetic_sampler(rates: list[float]) -> RejectionSampler:
    """Build a RejectionSampler in synthetic mode from a mocked base sampler
    and a mocked speculative config carrying ``rates``."""
    spec_config = Mock()
    spec_config.rejection_sample_method = "synthetic"
    spec_config.synthetic_acceptance_rates = rates

    base_sampler = Mock(spec=Sampler)
    base_sampler.logprobs_mode = "raw_logprobs"

    target_device = torch.device(DEVICE_TYPE)
    return RejectionSampler(base_sampler, spec_config, target_device)
def _make_sampling_metadata(all_greedy: bool) -> SamplingMetadata:
    """Create sampling metadata for a two-request batch: no temperature
    tensor when all requests are greedy, otherwise temperature 1.0 each."""
    if all_greedy:
        temperature = None
    else:
        temperature = torch.tensor([1.0, 1.0], device=DEVICE_TYPE)
    return create_sampling_metadata(all_greedy=all_greedy, temperature=temperature)
@pytest.mark.parametrize("all_greedy", [True, False])
def test_synthetic_all_accepted(all_greedy: bool):
    """When every synthetic rate is 1.0, all draft tokens are kept and the
    bonus token is appended after them."""
    sampler = _make_synthetic_sampler([1.0, 1.0])

    draft_tokens = [[1, 2], [3]]
    target_tokens = [[10, 20, 50], [30, 40]]
    sampling_metadata = _make_sampling_metadata(all_greedy)
    logits = create_logits_tensor(target_tokens)
    bonus_tokens = torch.tensor([50, 40], device=DEVICE_TYPE)
    spec_metadata = create_spec_decode_metadata(draft_tokens, logits)
    mock_sampler_output(sampler, bonus_tokens)

    result = sampler(spec_metadata, None, logits, sampling_metadata)

    # The drafts ([1, 2] and [3]) survive untouched; the shorter request is
    # padded with the placeholder id.
    want = torch.tensor(
        [[1, 2, 50], [3, 40, PLACEHOLDER_TOKEN_ID]],
        dtype=torch.int,
        device=DEVICE_TYPE,
    )
    assert torch.equal(result.sampled_token_ids, want)
@pytest.mark.parametrize("all_greedy", [True, False])
def test_synthetic_all_rejected(all_greedy: bool):
    """When every synthetic rate is 0.0, the first draft token is rejected
    for each request."""
    sampler = _make_synthetic_sampler([0.0, 0.0])

    draft_tokens = [[1, 2], [3]]
    target_tokens = [[10, 20, 50], [30, 40]]
    sampling_metadata = _make_sampling_metadata(all_greedy)
    logits = create_logits_tensor(target_tokens)
    bonus_tokens = torch.tensor([50, 40], device=DEVICE_TYPE)
    spec_metadata = create_spec_decode_metadata(draft_tokens, logits)
    mock_sampler_output(sampler, bonus_tokens)

    sampled = sampler(spec_metadata, None, logits, sampling_metadata).sampled_token_ids

    # Each request emits exactly one real token (the rejection fallback);
    # every later slot must hold the placeholder id.
    assert bool((sampled[:, 0] != PLACEHOLDER_TOKEN_ID).all())
    assert bool((sampled[:, 1:] == PLACEHOLDER_TOKEN_ID).all())
...@@ -2,33 +2,48 @@ ...@@ -2,33 +2,48 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import ( from vllm.config.speculative import SpeculativeConfig
compute_synthetic_rejection_sampler_params, from vllm.v1.spec_decode.utils import unconditional_to_conditional_rates
def test_unconditional_to_conditional_rates_basic():
    """Each conditional rate is the ratio of consecutive unconditional
    rates: c_0 = p_0 and c_i = p_i / p_{i-1}."""
    unconditional = [0.9, 0.5, 0.2]
    want = [0.9, 0.5 / 0.9, 0.2 / 0.5]
    got = unconditional_to_conditional_rates(unconditional)
    assert got == pytest.approx(want)
def test_unconditional_to_conditional_rates_handles_zero():
    """A zero unconditional rate must not cause a division by zero."""
    # Once a rate is zero, later conditional rates are clamped to 0 (the
    # chain has already terminated in the kernel, so they are unused).
    unconditional = [1.0, 0.6, 0.0, 0.0]
    want = [1.0, 0.6, 0.0, 0.0]
    got = unconditional_to_conditional_rates(unconditional)
    assert got == pytest.approx(want)
def test_unconditional_to_conditional_rates_all_ones():
    """All-ones unconditional rates map to all-ones conditional rates."""
    always_accept = [1.0, 1.0, 1.0]
    got = unconditional_to_conditional_rates(always_accept)
    assert got == pytest.approx([1.0, 1.0, 1.0])
@pytest.mark.parametrize(
"length,n,expected",
[
(2.6, 3, [1.0, 0.6, 0.0]),
(1.0, 3, [0.0, 0.0, 0.0]),
(4.0, 3, [1.0, 1.0, 1.0]),
(2.0, 3, [1.0, 0.0, 0.0]),
(3.5, 4, [1.0, 1.0, 0.5, 0.0]),
],
) )
def test_acceptance_length_to_rates(length, n, expected):
assert SpeculativeConfig._acceptance_length_to_rates(length, n) == pytest.approx(
expected
)
NUM_SPECULATIVE_STEPS = [1, 2, 3, 4, 5, 7, 10] def test_resolve_length_produces_minvariance_schedule():
ACCEPTANCE_RATES = [i / 100 for i in range(0, 100)] assert SpeculativeConfig._resolve_synthetic_acceptance_rates(
3, None, 2.6
) == pytest.approx([1.0, 0.6, 0.0])
@pytest.mark.parametrize("num_speculative_steps", NUM_SPECULATIVE_STEPS)
def test_compute_synthetic_rejection_sampler_params(num_speculative_steps: int):
    """Check that the generated (base rate, decay factor) pair reproduces the
    requested acceptance rate as the mean of the joint acceptance
    probabilities over all speculative positions."""
    tol = 1e-9
    for target_rate in ACCEPTANCE_RATES:
        base_rate, decay_factor = compute_synthetic_rejection_sampler_params(
            target_rate, num_speculative_steps, tol=tol
        )

        # The running product of per-position rates is the joint probability
        # that every position up to and including this one was accepted.
        joint_probs = []
        running = 1.0
        for step in range(num_speculative_steps):
            running *= base_rate * decay_factor**step
            joint_probs.append(running)
        mean_joint = sum(joint_probs) / num_speculative_steps

        assert abs(target_rate - mean_joint) < 10 * tol
        assert base_rate <= 1.0
...@@ -189,12 +189,64 @@ class SpeculativeConfig: ...@@ -189,12 +189,64 @@ class SpeculativeConfig:
distribution, but the latter yields a higher acceptance rate at the cost distribution, but the latter yields a higher acceptance rate at the cost
of more memory to cache draft logits.""" of more memory to cache draft logits."""
synthetic_acceptance_rate: float | None = None synthetic_acceptance_rates: list[float] | None = None
"""Average acceptance rate for synthetic rejection sampling. Draft """Per-position *unconditional* acceptance rates for synthetic rejection
tokens are accepted with a position-dependent probability that decays sampling. Position i's entry is the marginal probability that the first
geometrically, calibrated so that the mean rate across all speculative i+1 draft tokens are all accepted; the list must have length
positions equals this value. Only used when rejection_sample_method num_speculative_tokens, each entry in [0, 1], and be monotonically
is 'synthetic'. Must be in [0, 1].""" non-increasing. Only valid when rejection_sample_method is 'synthetic'.
Mutually exclusive with synthetic_acceptance_length."""
synthetic_acceptance_length: float | None = None
"""Target mean acceptance length for synthetic rejection sampling, in
[1, num_speculative_tokens + 1]. Resolved internally to
synthetic_acceptance_rates. Only valid when rejection_sample_method is 'synthetic'.
Mutually exclusive with synthetic_acceptance_rates."""
@staticmethod
def _acceptance_length_to_rates(length: float, n: int) -> list[float]:
    """Turn a mean acceptance length into ``n`` unconditional per-position
    rates using the minimum-variance schedule.

    The schedule is a run of 1.0 entries for each whole accepted draft
    token, one entry holding the fractional remainder, then 0.0 padding;
    the result is truncated to ``n`` entries.
    """
    expected_accepted = length - 1  # expected number of accepted draft tokens
    whole = int(expected_accepted)
    schedule = [1.0] * whole
    schedule.append(expected_accepted - whole)
    # A non-positive repeat count yields no padding at all.
    schedule.extend([0.0] * (n - whole - 1))
    return schedule[:n]
@staticmethod
def _resolve_synthetic_acceptance_rates(
n: int,
rates: list[float] | None,
length: float | None,
) -> list[float]:
"""Return per-position unconditional acceptance rates from exactly one
of `rates` or `length` (validates range, length, and monotonicity)."""
if (rates is None) == (length is None):
raise ValueError(
"rejection_sample_method='synthetic' requires exactly one of "
"synthetic_acceptance_rates or synthetic_acceptance_length."
)
if rates is not None:
if len(rates) != n:
raise ValueError(
f"synthetic_acceptance_rates must have length {n}, got {rates}."
)
if not all(0.0 <= r <= 1.0 for r in rates):
raise ValueError(
f"synthetic_acceptance_rates entries must be in [0, 1], "
f"got {rates}."
)
if any(rates[i] > rates[i - 1] for i in range(1, n)):
raise ValueError(
f"synthetic_acceptance_rates must be non-increasing, got {rates}."
)
return list(rates)
assert length is not None
if not 1.0 <= length <= float(n + 1):
raise ValueError(
f"synthetic_acceptance_length must be in [1, {n + 1}], got {length}."
)
return SpeculativeConfig._acceptance_length_to_rates(length, n)
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
...@@ -818,6 +870,23 @@ class SpeculativeConfig: ...@@ -818,6 +870,23 @@ class SpeculativeConfig:
f"than zero ({self.num_speculative_tokens})." f"than zero ({self.num_speculative_tokens})."
) )
if self.rejection_sample_method == "synthetic":
# Consolidate to per-position rates
self.synthetic_acceptance_rates = self._resolve_synthetic_acceptance_rates(
self.num_speculative_tokens,
self.synthetic_acceptance_rates,
self.synthetic_acceptance_length,
)
self.synthetic_acceptance_length = None
elif (
self.synthetic_acceptance_rates is not None
or self.synthetic_acceptance_length is not None
):
raise ValueError(
"synthetic_acceptance_rates / synthetic_acceptance_length "
"are only valid with rejection_sample_method='synthetic'."
)
if self.draft_model_config: if self.draft_model_config:
self.draft_model_config.verify_with_parallel_config( self.draft_model_config.verify_with_parallel_config(
self.draft_parallel_config self.draft_parallel_config
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import replace from dataclasses import replace
from typing import TYPE_CHECKING
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -17,6 +20,10 @@ from vllm.v1.sample.ops.penalties import apply_all_penalties ...@@ -17,6 +20,10 @@ from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.sampler import Sampler from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.utils import unconditional_to_conditional_rates
if TYPE_CHECKING:
from vllm.config.speculative import SpeculativeConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -50,13 +57,33 @@ class RejectionSampler(nn.Module): ...@@ -50,13 +57,33 @@ class RejectionSampler(nn.Module):
output tokens = accepted tokens + recovered tokens + bonus tokens output tokens = accepted tokens + recovered tokens + bonus tokens
""" """
def __init__(self, sampler: Sampler): def __init__(
self,
sampler: Sampler,
spec_config: SpeculativeConfig | None = None,
device: torch.device | None = None,
):
super().__init__() super().__init__()
self.sampler = sampler self.sampler = sampler
logprobs_mode = self.sampler.logprobs_mode logprobs_mode = self.sampler.logprobs_mode
self.is_processed_logprobs_mode = logprobs_mode.startswith("processed") self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
self.is_logits_logprobs_mode = logprobs_mode.endswith("logits") self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")
self.synthetic_conditional_rates: torch.Tensor | None = None
if (
spec_config is not None
and spec_config.rejection_sample_method == "synthetic"
):
assert spec_config.synthetic_acceptance_rates is not None
self.synthetic_conditional_rates = torch.tensor(
unconditional_to_conditional_rates(
spec_config.synthetic_acceptance_rates
),
dtype=torch.float32,
device=device,
)
self.synthetic_mode = self.synthetic_conditional_rates is not None
def forward( def forward(
self, self,
metadata: SpecDecodeMetadata, metadata: SpecDecodeMetadata,
...@@ -147,6 +174,8 @@ class RejectionSampler(nn.Module): ...@@ -147,6 +174,8 @@ class RejectionSampler(nn.Module):
target_logits, target_logits,
bonus_token_ids, bonus_token_ids,
sampling_metadata, sampling_metadata,
synthetic_mode=self.synthetic_mode,
synthetic_conditional_rates=self.synthetic_conditional_rates,
) )
logprobs_tensors = None logprobs_tensors = None
...@@ -362,6 +391,8 @@ def rejection_sample( ...@@ -362,6 +391,8 @@ def rejection_sample(
# [batch_size, 1] # [batch_size, 1]
bonus_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
synthetic_mode: bool = False,
synthetic_conditional_rates: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert draft_token_ids.ndim == 1 assert draft_token_ids.ndim == 1
assert draft_probs is None or draft_probs.ndim == 2 assert draft_probs is None or draft_probs.ndim == 2
...@@ -389,6 +420,20 @@ def rejection_sample( ...@@ -389,6 +420,20 @@ def rejection_sample(
is_greedy = None is_greedy = None
else: else:
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
# Generate uniform probabilities before either kernel because synthetic
# mode needs them in the greedy kernel too. Skip only when all requests
# are greedy *and* synthetic mode is off (the standard fast-path).
# [num_tokens]
uniform_probs: torch.Tensor | None = None
if synthetic_mode or not sampling_metadata.all_greedy:
uniform_probs = generate_uniform_probs(
num_tokens,
num_draft_tokens,
sampling_metadata.generators,
device,
)
if not sampling_metadata.all_random: if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests. # Rejection sampling for greedy sampling requests.
target_argmax = target_logits.argmax(dim=-1) target_argmax = target_logits.argmax(dim=-1)
...@@ -400,6 +445,9 @@ def rejection_sample( ...@@ -400,6 +445,9 @@ def rejection_sample(
bonus_token_ids, bonus_token_ids,
is_greedy, is_greedy,
max_spec_len, max_spec_len,
uniform_probs,
synthetic_conditional_rates,
SYNTHETIC_MODE=synthetic_mode,
) )
if sampling_metadata.all_greedy: if sampling_metadata.all_greedy:
return output_token_ids return output_token_ids
...@@ -408,15 +456,6 @@ def rejection_sample( ...@@ -408,15 +456,6 @@ def rejection_sample(
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32) target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
assert target_probs.is_contiguous() assert target_probs.is_contiguous()
# Generate uniform probabilities for rejection sampling.
# [num_tokens]
uniform_probs = generate_uniform_probs(
num_tokens,
num_draft_tokens,
sampling_metadata.generators,
device,
)
# Sample recovered tokens for each position. # Sample recovered tokens for each position.
# [num_tokens] # [num_tokens]
recovered_token_ids = sample_recovered_tokens( recovered_token_ids = sample_recovered_tokens(
...@@ -431,6 +470,7 @@ def rejection_sample( ...@@ -431,6 +470,7 @@ def rejection_sample(
) )
# Rejection sampling for random sampling requests. # Rejection sampling for random sampling requests.
assert uniform_probs is not None
rejection_random_sample_kernel[(batch_size,)]( rejection_random_sample_kernel[(batch_size,)](
output_token_ids, output_token_ids,
cu_num_draft_tokens, cu_num_draft_tokens,
...@@ -443,7 +483,9 @@ def rejection_sample( ...@@ -443,7 +483,9 @@ def rejection_sample(
is_greedy, is_greedy,
max_spec_len, max_spec_len,
vocab_size, vocab_size,
synthetic_conditional_rates,
NO_DRAFT_PROBS=draft_probs is None, NO_DRAFT_PROBS=draft_probs is None,
SYNTHETIC_MODE=synthetic_mode,
) )
return output_token_ids return output_token_ids
...@@ -658,6 +700,9 @@ def rejection_greedy_sample_kernel( ...@@ -658,6 +700,9 @@ def rejection_greedy_sample_kernel(
bonus_token_ids_ptr, # [batch_size] bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None is_greedy_ptr, # [batch_size] or None
max_spec_len, max_spec_len,
uniform_probs_ptr, # [num_tokens] or None (synthetic mode only)
synthetic_conditional_rates_ptr, # [num_speculative_tokens] or None
SYNTHETIC_MODE: tl.constexpr,
): ):
req_idx = tl.program_id(0) req_idx = tl.program_id(0)
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run, # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
...@@ -675,14 +720,20 @@ def rejection_greedy_sample_kernel( ...@@ -675,14 +720,20 @@ def rejection_greedy_sample_kernel(
for pos in range(num_draft_tokens): for pos in range(num_draft_tokens):
if not rejected: if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos) target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos).to(tl.int32)
if SYNTHETIC_MODE:
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
rate = tl.load(synthetic_conditional_rates_ptr + pos)
accepted = uniform_prob < rate
token_id = draft_token_id if accepted else target_argmax_id
rejected = not accepted
else:
token_id = target_argmax_id
rejected = draft_token_id != target_argmax_id
tl.store( tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
target_argmax_id, token_id,
) )
if draft_token_id != target_argmax_id:
# Reject.
rejected = True
if not rejected: if not rejected:
# If all tokens are accepted, append the bonus token. # If all tokens are accepted, append the bonus token.
...@@ -707,7 +758,9 @@ def rejection_random_sample_kernel( ...@@ -707,7 +758,9 @@ def rejection_random_sample_kernel(
is_greedy_ptr, # [batch_size] is_greedy_ptr, # [batch_size]
max_spec_len, max_spec_len,
vocab_size, vocab_size,
synthetic_conditional_rates_ptr, # [num_speculative_tokens] or None
NO_DRAFT_PROBS: tl.constexpr, NO_DRAFT_PROBS: tl.constexpr,
SYNTHETIC_MODE: tl.constexpr,
): ):
req_idx = tl.program_id(0) req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx) is_greedy = tl.load(is_greedy_ptr + req_idx)
...@@ -723,23 +776,28 @@ def rejection_random_sample_kernel( ...@@ -723,23 +776,28 @@ def rejection_random_sample_kernel(
for pos in range(num_draft_tokens): for pos in range(num_draft_tokens):
if not rejected: if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
if NO_DRAFT_PROBS: uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
draft_prob = 1 if SYNTHETIC_MODE:
rate = tl.load(synthetic_conditional_rates_ptr + pos)
accepted = uniform_prob < rate
else: else:
draft_prob = tl.load( if NO_DRAFT_PROBS:
draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id draft_prob = 1
else:
draft_prob = tl.load(
draft_probs_ptr
+ (start_idx + pos) * vocab_size
+ draft_token_id
)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
) )
target_prob = tl.load( # NOTE(woosuk): While the draft probability should never be 0,
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id # we check it to avoid NaNs. If it happens to be 0, we reject.
) accepted = draft_prob > 0 and target_prob / draft_prob >= uniform_prob
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) if accepted:
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
# Accept.
token_id = draft_token_id token_id = draft_token_id
else: else:
# Reject. Use recovered token.
rejected = True rejected = True
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos) token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store( tl.store(
......
...@@ -594,3 +594,9 @@ def update_num_computed_tokens_for_batch_change( ...@@ -594,3 +594,9 @@ def update_num_computed_tokens_for_batch_change(
num_accepted_tokens.copy_( num_accepted_tokens.copy_(
torch.where(participating, valid_counts, num_accepted_tokens) torch.where(participating, valid_counts, num_accepted_tokens)
) )
def unconditional_to_conditional_rates(rates: list[float]) -> list[float]:
    """Convert per-position unconditional rates into the conditional rates
    used by the early-terminating rejection loop.

    Each output is c_i = p_i / p_{i-1}, with an implicit p_{-1} = 1.0. When
    the previous rate is not positive, the output is 0.0 rather than a
    division by zero.
    """
    conditional: list[float] = []
    previous = 1.0
    for rate in rates:
        conditional.append(rate / previous if previous > 0.0 else 0.0)
        previous = rate
    return conditional
...@@ -220,6 +220,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -220,6 +220,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.rejection_sampler = RejectionSampler( self.rejection_sampler = RejectionSampler(
self.sampler, self.sampler,
self.speculative_config, self.speculative_config,
self.device,
) )
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs) self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
self.structured_outputs_worker = StructuredOutputsWorker( self.structured_outputs_worker = StructuredOutputsWorker(
......
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
from vllm.config import SpeculativeConfig from vllm.config import SpeculativeConfig
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.spec_decode.utils import unconditional_to_conditional_rates
from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
...@@ -15,7 +16,6 @@ from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import ...@@ -15,7 +16,6 @@ from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import
probabilistic_rejection_sample, probabilistic_rejection_sample,
) )
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import ( from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
compute_synthetic_rejection_sampler_params,
synthetic_rejection_sample, synthetic_rejection_sample,
) )
...@@ -102,24 +102,20 @@ class RejectionSampler: ...@@ -102,24 +102,20 @@ class RejectionSampler:
self, self,
sampler: Sampler, sampler: Sampler,
spec_config: SpeculativeConfig, spec_config: SpeculativeConfig,
device: torch.device,
): ):
self.sampler = sampler self.sampler = sampler
self.num_speculative_steps = spec_config.num_speculative_tokens self.num_speculative_steps = spec_config.num_speculative_tokens
self.rejection_sample_method = spec_config.rejection_sample_method self.rejection_sample_method = spec_config.rejection_sample_method
self.synthetic_conditional_rates: torch.Tensor | None = None
if self.rejection_sample_method == "synthetic": if self.rejection_sample_method == "synthetic":
synthetic_acceptance_rate = spec_config.synthetic_acceptance_rate assert spec_config.synthetic_acceptance_rates is not None
if ( self.synthetic_conditional_rates = torch.tensor(
synthetic_acceptance_rate is None unconditional_to_conditional_rates(
or not 0.0 <= synthetic_acceptance_rate <= 1.0 spec_config.synthetic_acceptance_rates
): ),
raise ValueError( dtype=torch.float32,
f"synthetic_acceptance_rate must be in [0, 1], " device=device,
f"but got {synthetic_acceptance_rate}"
)
self.base_acceptance_rate, self.decay_factor = (
compute_synthetic_rejection_sampler_params(
synthetic_acceptance_rate, self.num_speculative_steps
)
) )
def _get_logprobs_tensors( def _get_logprobs_tensors(
...@@ -218,8 +214,7 @@ class RejectionSampler: ...@@ -218,8 +214,7 @@ class RejectionSampler:
input_batch.positions[input_batch.logits_indices], input_batch.positions[input_batch.logits_indices],
input_batch.idx_mapping, input_batch.idx_mapping,
self.sampler.sampling_states.seeds.gpu, self.sampler.sampling_states.seeds.gpu,
self.base_acceptance_rate, self.synthetic_conditional_rates,
self.decay_factor,
self.num_speculative_steps, self.num_speculative_steps,
) )
else: else:
......
...@@ -5,8 +5,6 @@ import torch ...@@ -5,8 +5,6 @@ import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.gumbel import tl_rand64 from vllm.v1.worker.gpu.sample.gumbel import tl_rand64
MIN_ACCEPTANCE_DECAY_FACTOR = 0.85
@triton.jit @triton.jit
def _synthetic_rejection_sample_kernel( def _synthetic_rejection_sample_kernel(
...@@ -27,8 +25,8 @@ def _synthetic_rejection_sample_kernel( ...@@ -27,8 +25,8 @@ def _synthetic_rejection_sample_kernel(
idx_mapping_ptr, idx_mapping_ptr,
# [max_num_reqs] # [max_num_reqs]
seeds_ptr, seeds_ptr,
base_acceptance_rate, # [num_speculative_steps]
decay_factor, acceptance_rates_ptr,
): ):
req_idx = tl.program_id(0) req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx) start_idx = tl.load(cu_num_logits_ptr + req_idx)
...@@ -38,13 +36,13 @@ def _synthetic_rejection_sample_kernel( ...@@ -38,13 +36,13 @@ def _synthetic_rejection_sample_kernel(
seed = tl.load(seeds_ptr + req_state_idx) seed = tl.load(seeds_ptr + req_state_idx)
num_sampled = 0 num_sampled = 0
acceptance_rate = base_acceptance_rate
rejected = False rejected = False
for i in range(num_tokens - 1): for i in range(num_tokens - 1):
if not rejected: if not rejected:
logit_idx = start_idx + i logit_idx = start_idx + i
pos = tl.load(pos_ptr + logit_idx) pos = tl.load(pos_ptr + logit_idx)
u = tl_rand64(seed, pos, includes_zero=False) u = tl_rand64(seed, pos, includes_zero=False)
acceptance_rate = tl.load(acceptance_rates_ptr + i)
if u < acceptance_rate: if u < acceptance_rate:
sampled = tl.load(input_ids_ptr + logit_idx + 1).to(tl.int64) sampled = tl.load(input_ids_ptr + logit_idx + 1).to(tl.int64)
else: else:
...@@ -52,7 +50,6 @@ def _synthetic_rejection_sample_kernel( ...@@ -52,7 +50,6 @@ def _synthetic_rejection_sample_kernel(
rejected = True rejected = True
tl.store(sampled_ptr + req_idx * sampled_stride + i, sampled) tl.store(sampled_ptr + req_idx * sampled_stride + i, sampled)
num_sampled += 1 num_sampled += 1
acceptance_rate *= decay_factor
if not rejected: if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1) target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
tl.store( tl.store(
...@@ -75,8 +72,8 @@ def synthetic_rejection_sample( ...@@ -75,8 +72,8 @@ def synthetic_rejection_sample(
idx_mapping: torch.Tensor, idx_mapping: torch.Tensor,
# [max_num_reqs] # [max_num_reqs]
seed: torch.Tensor, seed: torch.Tensor,
base_acceptance_rate: float, # [num_speculative_steps]
decay_factor: float, acceptance_rates: torch.Tensor,
num_speculative_steps: int, num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1 num_reqs = cu_num_logits.shape[0] - 1
...@@ -92,56 +89,7 @@ def synthetic_rejection_sample( ...@@ -92,56 +89,7 @@ def synthetic_rejection_sample(
pos, pos,
idx_mapping, idx_mapping,
seed, seed,
base_acceptance_rate, acceptance_rates,
decay_factor,
num_warps=1, num_warps=1,
) )
return sampled, num_sampled return sampled, num_sampled
def compute_synthetic_rejection_sampler_params(
    p_avg: float, n: int, tol: float = 1e-9
) -> tuple[float, float]:
    """Find a base acceptance rate and decay factor whose mean joint
    acceptance probability over ``n`` positions matches ``p_avg``.

    The joint probability for position ``i`` (0-based) is modelled as
    ``a_0 ** (i + 1) * gamma ** (i * (i + 1) // 2)``.

    Args:
        p_avg: Desired mean joint acceptance probability.
        n: Number of speculative positions.
        tol: Width at which both bisection searches stop. It is forwarded
            to the nested searches; previously it was ignored and they
            always ran at their own default of 1e-9.

    Returns:
        ``(base_rate, decay_factor)``.
    """

    def mean_joint_prob(a_0: float, gamma: float, n: int):
        # Average of the joint acceptance probabilities over all positions.
        total = 0.0
        for i in range(n):
            total += a_0 ** (i + 1) * gamma ** (i * (i + 1) // 2)
        return total / n

    def min_valid_decay_factor(p: float, n: int, tol: float = 1e-9) -> float:
        low, high = MIN_ACCEPTANCE_DECAY_FACTOR, 1.0
        # If the smallest allowed decay factor can already reach p with a
        # base rate of 1, use it directly.
        if mean_joint_prob(1, low, n) >= p:
            return low

        # Sweep for a gamma decay factor that is guaranteed
        # to yield a base acceptance rate <= 1.
        while (high - low) > tol:
            mid = (low + high) / 2
            if mean_joint_prob(1, mid, n) >= p:
                high = mid
            else:
                low = mid
        return high

    def compute_base_acceptance_rate(
        p_avg: float, gamma: float, n: int, tol: float = 1e-9
    ) -> float:
        # Degenerate targets need no search.
        if p_avg <= 0.0:
            return 0.0
        if p_avg >= 1.0:
            return 1.0

        # Sweep for a base acceptance rate that yields
        # the desired mean joint probability.
        low, high = 0.0, 1.0
        while (high - low) > tol:
            mid = (low + high) / 2
            if mean_joint_prob(mid, gamma, n) >= p_avg:
                high = mid
            else:
                low = mid
        return high

    # Forward the caller's tolerance so both searches honour it.
    decay_factor = min_valid_decay_factor(p_avg, n, tol)
    base_rate = compute_base_acceptance_rate(p_avg, decay_factor, n, tol)
    return base_rate, decay_factor
...@@ -577,7 +577,9 @@ class GPUModelRunner( ...@@ -577,7 +577,9 @@ class GPUModelRunner(
"Unknown speculative decoding method: " "Unknown speculative decoding method: "
f"{self.speculative_config.method}" f"{self.speculative_config.method}"
) )
self.rejection_sampler = RejectionSampler(self.sampler) self.rejection_sampler = RejectionSampler(
self.sampler, self.speculative_config, self.device
)
self.num_spec_tokens = 0 self.num_spec_tokens = 0
self.valid_sampled_token_count_gpu: torch.Tensor | None = None self.valid_sampled_token_count_gpu: torch.Tensor | None = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment