Commit 626daa20 (unverified), authored by Benjamin Chislett, committed by GitHub
Browse files

[Feat] Unified Synthetic Acceptance Rate for V1 and V2 (#40662)


Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent fe85a92e
......@@ -1310,6 +1310,54 @@ def test_dflash_acceptance_rates(dflash_config):
cleanup_dist_env_and_memory()
@single_gpu_only
def test_synthetic_acceptance_rate():
    """Verify that synthetic rejection sampling produces an acceptance
    length close to the requested mean acceptance length.

    An engine is built with ``rejection_sample_method="synthetic"`` and a
    target ``synthetic_acceptance_length``; 50 prompts are generated with
    64 tokens each and the acceptance length derived from the engine
    metrics must land within ``tolerance`` of the target.
    """
    num_spec_tokens = 3
    expected_acceptance_len = 1.875
    tolerance = 0.15

    spec_llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        trust_remote_code=True,
        speculative_config={
            "method": "eagle3",
            "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
            "num_speculative_tokens": num_spec_tokens,
            "max_model_len": 2048,
            "rejection_sample_method": "synthetic",
            "synthetic_acceptance_length": expected_acceptance_len,
        },
        max_model_len=2048,
        enforce_eager=True,
        # Stats logging must stay enabled: the metrics are read back below.
        disable_log_stats=False,
    )
    try:
        test_prompts = get_test_prompts(mm_enabled=False, num_prompts=50)
        spec_llm.chat(
            test_prompts,
            SamplingParams(temperature=0, max_tokens=64, ignore_eos=True),
        )

        metrics = spec_llm.get_metrics()
        acceptance_len = compute_acceptance_len(metrics)

        print(
            f"Synthetic acceptance length: {acceptance_len:.3f}"
            f" (expected={expected_acceptance_len:.3f},"
            f" tolerance=±{tolerance})"
        )
        assert abs(acceptance_len - expected_acceptance_len) <= tolerance, (
            f"Synthetic acceptance length {acceptance_len:.3f} is not within"
            f" ±{tolerance} of expected {expected_acceptance_len:.3f}"
        )
    finally:
        # Tear the engine down even when the assertion (or generation) fails;
        # otherwise the failing frame keeps the engine referenced and its
        # memory stays allocated for the tests that run afterwards.
        del spec_llm
        torch.accelerator.empty_cache()
        cleanup_dist_env_and_memory()
def test_dflash_correctness(dflash_config):
"""
E2E test for DFlash (block diffusion) speculative decoding.
......
......@@ -933,3 +933,64 @@ def test_sample_recovered_tokens(
device=DEVICE_TYPE,
)
assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
########################### Tests for Synthetic Rejection Sampling #########
def _make_synthetic_sampler(rates: list[float]) -> RejectionSampler:
    """Build a ``RejectionSampler`` configured for synthetic rejection.

    The wrapped sampler and the speculative config are both mocks; ``rates``
    is installed as the config's ``synthetic_acceptance_rates``.
    """
    spec_config = Mock()
    spec_config.rejection_sample_method = "synthetic"
    spec_config.synthetic_acceptance_rates = rates

    base_sampler = Mock(spec=Sampler)
    base_sampler.logprobs_mode = "raw_logprobs"

    target_device = torch.device(DEVICE_TYPE)
    return RejectionSampler(base_sampler, spec_config, target_device)
def _make_sampling_metadata(all_greedy: bool) -> SamplingMetadata:
    """Create sampling metadata for a two-request batch.

    Greedy batches carry no temperature tensor; otherwise both requests
    sample with temperature 1.0.
    """
    if all_greedy:
        temperature = None
    else:
        temperature = torch.tensor([1.0, 1.0], device=DEVICE_TYPE)
    return create_sampling_metadata(all_greedy=all_greedy, temperature=temperature)
@pytest.mark.parametrize("all_greedy", [True, False])
def test_synthetic_all_accepted(all_greedy: bool):
    """With all rates=1.0, every draft token is accepted."""
    draft_tokens = [[1, 2], [3]]
    target_tokens = [[10, 20, 50], [30, 40]]
    bonus_tokens = torch.tensor([50, 40], device=DEVICE_TYPE)

    rejection_sampler = _make_synthetic_sampler([1.0, 1.0])
    sampling_metadata = _make_sampling_metadata(all_greedy)
    target_logits = create_logits_tensor(target_tokens)
    spec_metadata = create_spec_decode_metadata(draft_tokens, target_logits)
    mock_sampler_output(rejection_sampler, bonus_tokens)

    result = rejection_sampler(spec_metadata, None, target_logits, sampling_metadata)

    # Each row is the accepted drafts followed by the bonus token; the
    # shorter request is padded with the placeholder id.
    want = torch.tensor(
        [[1, 2, 50], [3, 40, PLACEHOLDER_TOKEN_ID]],
        dtype=torch.int,
        device=DEVICE_TYPE,
    )
    assert torch.equal(result.sampled_token_ids, want)
@pytest.mark.parametrize("all_greedy", [True, False])
def test_synthetic_all_rejected(all_greedy: bool):
    """With all rates=0.0, the first token is always rejected."""
    draft_tokens = [[1, 2], [3]]
    target_tokens = [[10, 20, 50], [30, 40]]
    bonus_tokens = torch.tensor([50, 40], device=DEVICE_TYPE)

    rejection_sampler = _make_synthetic_sampler([0.0, 0.0])
    sampling_metadata = _make_sampling_metadata(all_greedy)
    target_logits = create_logits_tensor(target_tokens)
    spec_metadata = create_spec_decode_metadata(draft_tokens, target_logits)
    mock_sampler_output(rejection_sampler, bonus_tokens)

    sampled = rejection_sampler(
        spec_metadata, None, target_logits, sampling_metadata
    ).sampled_token_ids

    # Exactly one token emitted per sequence (the rejection fallback),
    # followed by placeholders.
    assert (sampled[:, 0] != PLACEHOLDER_TOKEN_ID).all()
    assert (sampled[:, 1:] == PLACEHOLDER_TOKEN_ID).all()
......@@ -2,33 +2,48 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
compute_synthetic_rejection_sampler_params,
)
from vllm.config.speculative import SpeculativeConfig
from vllm.v1.spec_decode.utils import unconditional_to_conditional_rates
def test_unconditional_to_conditional_rates_basic():
    """Each conditional rate is the ratio to the preceding unconditional one."""
    # c_0 = p_0; c_i = p_i / p_{i-1}
    unconditional = [0.9, 0.5, 0.2]
    want = [0.9, 0.5 / 0.9, 0.2 / 0.5]
    got = unconditional_to_conditional_rates(unconditional)
    assert got == pytest.approx(want)
NUM_SPECULATIVE_STEPS = [1, 2, 3, 4, 5, 7, 10]
ACCEPTANCE_RATES = [i / 100 for i in range(0, 100)]
def test_unconditional_to_conditional_rates_handles_zero():
    """A zero unconditional rate must not cause a division by zero."""
    # After a zero, subsequent conditional rates are clamped to 0 (the chain
    # has already terminated in the kernel, so these values are unused).
    unconditional = [1.0, 0.6, 0.0, 0.0]
    want = [1.0, 0.6, 0.0, 0.0]
    got = unconditional_to_conditional_rates(unconditional)
    assert got == pytest.approx(want)
@pytest.mark.parametrize("num_speculative_steps", NUM_SPECULATIVE_STEPS)
def test_compute_synthetic_rejection_sampler_params(num_speculative_steps: int):
"""Test that the base acceptance rate and decay factor generated for
synthetic rejection sampling have a mean joint acceptance probability
that matches the desired acceptance rate."""
tol = 1e-9
for desired_acceptance_rate in ACCEPTANCE_RATES:
base_rate, decay_factor = compute_synthetic_rejection_sampler_params(
desired_acceptance_rate, num_speculative_steps, tol=tol
def test_unconditional_to_conditional_rates_all_ones():
assert unconditional_to_conditional_rates([1.0, 1.0, 1.0]) == pytest.approx(
[1.0, 1.0, 1.0]
)
# Compute the mean of joint acceptance probabilities across
# all speculative positions.
joint_prob = 1.0
mean_joint = 0.0
for i in range(num_speculative_steps):
joint_prob *= base_rate * decay_factor**i
mean_joint += joint_prob
mean_joint /= num_speculative_steps
assert abs(desired_acceptance_rate - mean_joint) < 10 * tol
assert base_rate <= 1.0
@pytest.mark.parametrize(
    ("length", "n", "expected"),
    [
        (2.6, 3, [1.0, 0.6, 0.0]),
        (1.0, 3, [0.0, 0.0, 0.0]),
        (4.0, 3, [1.0, 1.0, 1.0]),
        (2.0, 3, [1.0, 0.0, 0.0]),
        (3.5, 4, [1.0, 1.0, 0.5, 0.0]),
    ],
)
def test_acceptance_length_to_rates(length, n, expected):
    """A mean acceptance length maps to the expected per-position rates."""
    rates = SpeculativeConfig._acceptance_length_to_rates(length, n)
    assert rates == pytest.approx(expected)
def test_resolve_length_produces_minvariance_schedule():
    """Resolving from a length alone yields the length-derived schedule."""
    resolved = SpeculativeConfig._resolve_synthetic_acceptance_rates(
        3,
        None,
        2.6,
    )
    assert resolved == pytest.approx([1.0, 0.6, 0.0])
......@@ -189,12 +189,64 @@ class SpeculativeConfig:
distribution, but the latter yields a higher acceptance rate at the cost
of more memory to cache draft logits."""
synthetic_acceptance_rate: float | None = None
"""Average acceptance rate for synthetic rejection sampling. Draft
tokens are accepted with a position-dependent probability that decays
geometrically, calibrated so that the mean rate across all speculative
positions equals this value. Only used when rejection_sample_method
is 'synthetic'. Must be in [0, 1]."""
synthetic_acceptance_rates: list[float] | None = None
"""Per-position *unconditional* acceptance rates for synthetic rejection
sampling. Position i's entry is the marginal probability that the first
i+1 draft tokens are all accepted; the list must have length
num_speculative_tokens, each entry in [0, 1], and be monotonically
non-increasing. Only valid when rejection_sample_method is 'synthetic'.
Mutually exclusive with synthetic_acceptance_length."""
synthetic_acceptance_length: float | None = None
"""Target mean acceptance length for synthetic rejection sampling, in
[1, num_speculative_tokens + 1]. Resolved internally to
synthetic_acceptance_rates. Only valid when rejection_sample_method is 'synthetic'.
Mutually exclusive with synthetic_acceptance_rates."""
@staticmethod
def _acceptance_length_to_rates(length: float, n: int) -> list[float]:
    """Mean acceptance length to unconditional per-position rates, using
    the minimum-variance schedule.

    The expected number of accepted draft tokens is ``length - 1``. Leading
    positions get rate 1.0, one position absorbs the fractional remainder,
    and the remaining positions get 0.0.

    Args:
        length: Target mean acceptance length.
        n: Number of speculative positions (length of the returned list).

    Returns:
        ``n`` non-increasing rates, each in [0, 1]. A ``length`` outside
        ``[1, n + 1]`` is clamped to that range rather than producing
        negative rates.
    """
    num_drafts = length - 1  # expected number of accepted draft tokens
    # Position i receives whatever part of the budget is left after the
    # first i positions took 1.0 each, limited to the valid range [0, 1].
    return [min(1.0, max(0.0, num_drafts - i)) for i in range(n)]
@staticmethod
def _resolve_synthetic_acceptance_rates(
n: int,
rates: list[float] | None,
length: float | None,
) -> list[float]:
"""Return per-position unconditional acceptance rates from exactly one
of `rates` or `length` (validates range, length, and monotonicity)."""
if (rates is None) == (length is None):
raise ValueError(
"rejection_sample_method='synthetic' requires exactly one of "
"synthetic_acceptance_rates or synthetic_acceptance_length."
)
if rates is not None:
if len(rates) != n:
raise ValueError(
f"synthetic_acceptance_rates must have length {n}, got {rates}."
)
if not all(0.0 <= r <= 1.0 for r in rates):
raise ValueError(
f"synthetic_acceptance_rates entries must be in [0, 1], "
f"got {rates}."
)
if any(rates[i] > rates[i - 1] for i in range(1, n)):
raise ValueError(
f"synthetic_acceptance_rates must be non-increasing, got {rates}."
)
return list(rates)
assert length is not None
if not 1.0 <= length <= float(n + 1):
raise ValueError(
f"synthetic_acceptance_length must be in [1, {n + 1}], got {length}."
)
return SpeculativeConfig._acceptance_length_to_rates(length, n)
def compute_hash(self) -> str:
"""
......@@ -818,6 +870,23 @@ class SpeculativeConfig:
f"than zero ({self.num_speculative_tokens})."
)
if self.rejection_sample_method == "synthetic":
# Consolidate to per-position rates
self.synthetic_acceptance_rates = self._resolve_synthetic_acceptance_rates(
self.num_speculative_tokens,
self.synthetic_acceptance_rates,
self.synthetic_acceptance_length,
)
self.synthetic_acceptance_length = None
elif (
self.synthetic_acceptance_rates is not None
or self.synthetic_acceptance_length is not None
):
raise ValueError(
"synthetic_acceptance_rates / synthetic_acceptance_length "
"are only valid with rejection_sample_method='synthetic'."
)
if self.draft_model_config:
self.draft_model_config.verify_with_parallel_config(
self.draft_parallel_config
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import replace
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
......@@ -17,6 +20,10 @@ from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.utils import unconditional_to_conditional_rates
if TYPE_CHECKING:
from vllm.config.speculative import SpeculativeConfig
logger = init_logger(__name__)
......@@ -50,13 +57,33 @@ class RejectionSampler(nn.Module):
output tokens = accepted tokens + recovered tokens + bonus tokens
"""
def __init__(self, sampler: Sampler):
def __init__(
    self,
    sampler: Sampler,
    spec_config: SpeculativeConfig | None = None,
    device: torch.device | None = None,
):
    """Wrap ``sampler`` and, if requested, set up synthetic rejection.

    Args:
        sampler: Sampler whose ``logprobs_mode`` selects the logprobs flags.
        spec_config: Speculative config. When its ``rejection_sample_method``
            is ``"synthetic"``, its ``synthetic_acceptance_rates`` are
            converted to conditional rates and stored as a float32 tensor.
        device: Device on which the conditional-rate tensor is created.
    """
    super().__init__()
    self.sampler = sampler

    mode = self.sampler.logprobs_mode
    self.is_processed_logprobs_mode = mode.startswith("processed")
    self.is_logits_logprobs_mode = mode.endswith("logits")

    conditional_rates: torch.Tensor | None = None
    if (
        spec_config is not None
        and spec_config.rejection_sample_method == "synthetic"
    ):
        unconditional = spec_config.synthetic_acceptance_rates
        assert unconditional is not None
        conditional_rates = torch.tensor(
            unconditional_to_conditional_rates(unconditional),
            dtype=torch.float32,
            device=device,
        )

    self.synthetic_conditional_rates: torch.Tensor | None = conditional_rates
    # Synthetic mode is on exactly when a rate tensor was built.
    self.synthetic_mode = conditional_rates is not None
def forward(
self,
metadata: SpecDecodeMetadata,
......@@ -147,6 +174,8 @@ class RejectionSampler(nn.Module):
target_logits,
bonus_token_ids,
sampling_metadata,
synthetic_mode=self.synthetic_mode,
synthetic_conditional_rates=self.synthetic_conditional_rates,
)
logprobs_tensors = None
......@@ -362,6 +391,8 @@ def rejection_sample(
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
synthetic_mode: bool = False,
synthetic_conditional_rates: torch.Tensor | None = None,
) -> torch.Tensor:
assert draft_token_ids.ndim == 1
assert draft_probs is None or draft_probs.ndim == 2
......@@ -389,6 +420,20 @@ def rejection_sample(
is_greedy = None
else:
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
# Generate uniform probabilities before either kernel because synthetic
# mode needs them in the greedy kernel too. Skip only when all requests
# are greedy *and* synthetic mode is off (the standard fast-path).
# [num_tokens]
uniform_probs: torch.Tensor | None = None
if synthetic_mode or not sampling_metadata.all_greedy:
uniform_probs = generate_uniform_probs(
num_tokens,
num_draft_tokens,
sampling_metadata.generators,
device,
)
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_logits.argmax(dim=-1)
......@@ -400,6 +445,9 @@ def rejection_sample(
bonus_token_ids,
is_greedy,
max_spec_len,
uniform_probs,
synthetic_conditional_rates,
SYNTHETIC_MODE=synthetic_mode,
)
if sampling_metadata.all_greedy:
return output_token_ids
......@@ -408,15 +456,6 @@ def rejection_sample(
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
assert target_probs.is_contiguous()
# Generate uniform probabilities for rejection sampling.
# [num_tokens]
uniform_probs = generate_uniform_probs(
num_tokens,
num_draft_tokens,
sampling_metadata.generators,
device,
)
# Sample recovered tokens for each position.
# [num_tokens]
recovered_token_ids = sample_recovered_tokens(
......@@ -431,6 +470,7 @@ def rejection_sample(
)
# Rejection sampling for random sampling requests.
assert uniform_probs is not None
rejection_random_sample_kernel[(batch_size,)](
output_token_ids,
cu_num_draft_tokens,
......@@ -443,7 +483,9 @@ def rejection_sample(
is_greedy,
max_spec_len,
vocab_size,
synthetic_conditional_rates,
NO_DRAFT_PROBS=draft_probs is None,
SYNTHETIC_MODE=synthetic_mode,
)
return output_token_ids
......@@ -658,6 +700,9 @@ def rejection_greedy_sample_kernel(
bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None
max_spec_len,
uniform_probs_ptr, # [num_tokens] or None (synthetic mode only)
synthetic_conditional_rates_ptr, # [num_speculative_tokens] or None
SYNTHETIC_MODE: tl.constexpr,
):
req_idx = tl.program_id(0)
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
......@@ -675,14 +720,20 @@ def rejection_greedy_sample_kernel(
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos).to(tl.int32)
if SYNTHETIC_MODE:
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
rate = tl.load(synthetic_conditional_rates_ptr + pos)
accepted = uniform_prob < rate
token_id = draft_token_id if accepted else target_argmax_id
rejected = not accepted
else:
token_id = target_argmax_id
rejected = draft_token_id != target_argmax_id
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
target_argmax_id,
token_id,
)
if draft_token_id != target_argmax_id:
# Reject.
rejected = True
if not rejected:
# If all tokens are accepted, append the bonus token.
......@@ -707,7 +758,9 @@ def rejection_random_sample_kernel(
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
synthetic_conditional_rates_ptr, # [num_speculative_tokens] or None
NO_DRAFT_PROBS: tl.constexpr,
SYNTHETIC_MODE: tl.constexpr,
):
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
......@@ -723,23 +776,28 @@ def rejection_random_sample_kernel(
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
if SYNTHETIC_MODE:
rate = tl.load(synthetic_conditional_rates_ptr + pos)
accepted = uniform_prob < rate
else:
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(
draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
draft_probs_ptr
+ (start_idx + pos) * vocab_size
+ draft_token_id
)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
# Accept.
accepted = draft_prob > 0 and target_prob / draft_prob >= uniform_prob
if accepted:
token_id = draft_token_id
else:
# Reject. Use recovered token.
rejected = True
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store(
......
......@@ -594,3 +594,9 @@ def update_num_computed_tokens_for_batch_change(
num_accepted_tokens.copy_(
torch.where(participating, valid_counts, num_accepted_tokens)
)
def unconditional_to_conditional_rates(rates: list[float]) -> list[float]:
    """Convert per-position unconditional rates to per-position conditional
    rates for the early-terminating rejection loop (c_i = p_i / p_{i-1}).

    Position 0 is divided by 1.0. Whenever the preceding unconditional rate
    is not positive, the conditional rate is reported as 0.0 instead of
    dividing by zero.
    """
    conditional: list[float] = []
    previous = 1.0  # nothing precedes position 0, so its divisor is 1
    for current in rates:
        if previous > 0.0:
            conditional.append(current / previous)
        else:
            conditional.append(0.0)
        previous = current
    return conditional
......@@ -220,6 +220,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.rejection_sampler = RejectionSampler(
self.sampler,
self.speculative_config,
self.device,
)
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
self.structured_outputs_worker = StructuredOutputsWorker(
......
......@@ -5,6 +5,7 @@ import torch
from vllm.config import SpeculativeConfig
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.spec_decode.utils import unconditional_to_conditional_rates
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
......@@ -15,7 +16,6 @@ from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import
probabilistic_rejection_sample,
)
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
compute_synthetic_rejection_sampler_params,
synthetic_rejection_sample,
)
......@@ -102,24 +102,20 @@ class RejectionSampler:
self,
sampler: Sampler,
spec_config: SpeculativeConfig,
device: torch.device,
):
self.sampler = sampler
self.num_speculative_steps = spec_config.num_speculative_tokens
self.rejection_sample_method = spec_config.rejection_sample_method
self.synthetic_conditional_rates: torch.Tensor | None = None
if self.rejection_sample_method == "synthetic":
synthetic_acceptance_rate = spec_config.synthetic_acceptance_rate
if (
synthetic_acceptance_rate is None
or not 0.0 <= synthetic_acceptance_rate <= 1.0
):
raise ValueError(
f"synthetic_acceptance_rate must be in [0, 1], "
f"but got {synthetic_acceptance_rate}"
)
self.base_acceptance_rate, self.decay_factor = (
compute_synthetic_rejection_sampler_params(
synthetic_acceptance_rate, self.num_speculative_steps
)
assert spec_config.synthetic_acceptance_rates is not None
self.synthetic_conditional_rates = torch.tensor(
unconditional_to_conditional_rates(
spec_config.synthetic_acceptance_rates
),
dtype=torch.float32,
device=device,
)
def _get_logprobs_tensors(
......@@ -218,8 +214,7 @@ class RejectionSampler:
input_batch.positions[input_batch.logits_indices],
input_batch.idx_mapping,
self.sampler.sampling_states.seeds.gpu,
self.base_acceptance_rate,
self.decay_factor,
self.synthetic_conditional_rates,
self.num_speculative_steps,
)
else:
......
......@@ -5,8 +5,6 @@ import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.gumbel import tl_rand64
MIN_ACCEPTANCE_DECAY_FACTOR = 0.85
@triton.jit
def _synthetic_rejection_sample_kernel(
......@@ -27,8 +25,8 @@ def _synthetic_rejection_sample_kernel(
idx_mapping_ptr,
# [max_num_reqs]
seeds_ptr,
base_acceptance_rate,
decay_factor,
# [num_speculative_steps]
acceptance_rates_ptr,
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
......@@ -38,13 +36,13 @@ def _synthetic_rejection_sample_kernel(
seed = tl.load(seeds_ptr + req_state_idx)
num_sampled = 0
acceptance_rate = base_acceptance_rate
rejected = False
for i in range(num_tokens - 1):
if not rejected:
logit_idx = start_idx + i
pos = tl.load(pos_ptr + logit_idx)
u = tl_rand64(seed, pos, includes_zero=False)
acceptance_rate = tl.load(acceptance_rates_ptr + i)
if u < acceptance_rate:
sampled = tl.load(input_ids_ptr + logit_idx + 1).to(tl.int64)
else:
......@@ -52,7 +50,6 @@ def _synthetic_rejection_sample_kernel(
rejected = True
tl.store(sampled_ptr + req_idx * sampled_stride + i, sampled)
num_sampled += 1
acceptance_rate *= decay_factor
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
tl.store(
......@@ -75,8 +72,8 @@ def synthetic_rejection_sample(
idx_mapping: torch.Tensor,
# [max_num_reqs]
seed: torch.Tensor,
base_acceptance_rate: float,
decay_factor: float,
# [num_speculative_steps]
acceptance_rates: torch.Tensor,
num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
......@@ -92,56 +89,7 @@ def synthetic_rejection_sample(
pos,
idx_mapping,
seed,
base_acceptance_rate,
decay_factor,
acceptance_rates,
num_warps=1,
)
return sampled, num_sampled
def compute_synthetic_rejection_sampler_params(
    p_avg: float, n: int, tol: float = 1e-9
) -> tuple[float, float]:
    """Find a base acceptance rate and decay factor for a target mean rate.

    Position ``i`` is accepted with conditional probability
    ``base_rate * decay_factor**i``, so the joint probability of accepting
    the first ``i + 1`` tokens is
    ``base_rate**(i + 1) * decay_factor**(i * (i + 1) // 2)``. Both values
    are found by bisection so that the mean of those joint probabilities
    over ``n`` positions matches ``p_avg``.

    Args:
        p_avg: Desired mean joint acceptance probability.
        n: Number of speculative positions.
        tol: Bisection stops once the search interval is no wider than this.
            It is applied to both searches.

    Returns:
        ``(base_rate, decay_factor)``.

    Raises:
        ValueError: If ``tol`` is not positive; the bisection could then
            fail to terminate.
    """
    if not tol > 0.0:
        raise ValueError(f"tol must be positive, got {tol}")

    def mean_joint_prob(a_0: float, gamma: float, n: int) -> float:
        # Mean over positions of the joint acceptance probability.
        total = 0.0
        for i in range(n):
            total += a_0 ** (i + 1) * gamma ** (i * (i + 1) // 2)
        return total / n

    def min_valid_decay_factor(p: float, n: int, tol: float) -> float:
        low, high = MIN_ACCEPTANCE_DECAY_FACTOR, 1.0
        if mean_joint_prob(1, low, n) >= p:
            return low

        # Sweep for a gamma decay factor that is guaranteed
        # to yield a base acceptance rate <= 1.
        while (high - low) > tol:
            mid = (low + high) / 2
            if mean_joint_prob(1, mid, n) >= p:
                high = mid
            else:
                low = mid
        return high

    def compute_base_acceptance_rate(
        p_avg: float, gamma: float, n: int, tol: float
    ) -> float:
        if p_avg <= 0.0:
            return 0.0
        if p_avg >= 1.0:
            return 1.0

        # Sweep for a base acceptance rate that yields
        # the desired mean joint probability.
        low, high = 0.0, 1.0
        while (high - low) > tol:
            mid = (low + high) / 2
            if mean_joint_prob(mid, gamma, n) >= p_avg:
                high = mid
            else:
                low = mid
        return high

    # Forward the caller's tolerance; previously both helpers silently fell
    # back to their own 1e-9 default regardless of the ``tol`` argument.
    decay_factor = min_valid_decay_factor(p_avg, n, tol)
    base_rate = compute_base_acceptance_rate(p_avg, decay_factor, n, tol)
    return base_rate, decay_factor
......@@ -577,7 +577,9 @@ class GPUModelRunner(
"Unknown speculative decoding method: "
f"{self.speculative_config.method}"
)
self.rejection_sampler = RejectionSampler(self.sampler)
self.rejection_sampler = RejectionSampler(
self.sampler, self.speculative_config, self.device
)
self.num_spec_tokens = 0
self.valid_sampled_token_count_gpu: torch.Tensor | None = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment