Unverified Commit 5daf6227 authored by Giancarlo Delfin's avatar Giancarlo Delfin Committed by GitHub
Browse files

[Model Runner V2] Fuse probabilistic rejection sample kernels (#38496)


Signed-off-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
parent ad330442
...@@ -100,11 +100,13 @@ steps: ...@@ -100,11 +100,13 @@ steps:
- vllm/v1/worker/gpu/ - vllm/v1/worker/gpu/
- vllm/v1/worker/gpu_worker.py - vllm/v1/worker/gpu_worker.py
- tests/v1/spec_decode/test_max_len.py - tests/v1/spec_decode/test_max_len.py
- tests/v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
- tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py - tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py
- tests/v1/e2e/spec_decode/test_spec_decode.py - tests/v1/e2e/spec_decode/test_spec_decode.py
commands: commands:
- set -x - set -x
- export VLLM_USE_V2_MODEL_RUNNER=1 - export VLLM_USE_V2_MODEL_RUNNER=1
- pytest -v -s v1/spec_decode/test_max_len.py -k "eagle or mtp" - pytest -v -s v1/spec_decode/test_max_len.py -k "eagle or mtp"
- pytest -v -s v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
- pytest -v -s v1/spec_decode/test_synthetic_rejection_sampler_utils.py - pytest -v -s v1/spec_decode/test_synthetic_rejection_sampler_utils.py
- pytest -v -s v1/e2e/spec_decode/test_spec_decode.py -k "eagle or mtp" - pytest -v -s v1/e2e/spec_decode/test_spec_decode.py -k "eagle or mtp"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import pytest
import torch
from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import (
probabilistic_rejection_sample,
)
VOCAB_SIZE = 4096
# Skip if no CUDA - Triton kernel requires GPU
pytest.importorskip("triton")
if not torch.cuda.is_available():
pytest.skip("CUDA required for rejection sampler tests", allow_module_level=True)
def _build_rejection_sample_inputs(
target_logits_1d: torch.Tensor,
draft_logits_1d: torch.Tensor,
num_speculative_steps: int,
temperature: float,
num_trials: int,
) -> dict:
device = target_logits_1d.device
vocab_size = target_logits_1d.shape[0]
K = num_speculative_steps
num_logits = num_trials * (K + 1)
target_logits = target_logits_1d.unsqueeze(0).expand(num_logits, -1).contiguous()
draft_logits = (
draft_logits_1d.view(1, 1, vocab_size).expand(num_trials, K, -1).contiguous()
)
draft_probs = torch.softmax(draft_logits_1d, dim=0)
draft_tokens = torch.multinomial(
draft_probs.expand(num_trials, -1), K, replacement=True
)
draft_sampled_2d = torch.zeros(num_trials, K + 1, dtype=torch.int64, device=device)
draft_sampled_2d[:, 1:] = draft_tokens
draft_sampled = draft_sampled_2d.reshape(-1)
cu_num_logits = torch.arange(num_trials + 1, dtype=torch.int32, device=device) * (
K + 1
)
pos = torch.arange(num_logits, dtype=torch.int32, device=device)
idx_mapping = torch.arange(num_trials, dtype=torch.int32, device=device)
expanded_idx_mapping = torch.arange(
num_trials, dtype=torch.int32, device=device
).repeat_interleave(K + 1)
expanded_local_pos = torch.arange(K + 1, dtype=torch.int32, device=device).repeat(
num_trials
)
temp_tensor = torch.full(
(num_trials,), temperature, dtype=torch.float32, device=device
)
seed = torch.arange(num_trials, dtype=torch.int64, device=device)
return dict(
target_logits=target_logits,
draft_logits=draft_logits,
draft_sampled=draft_sampled,
cu_num_logits=cu_num_logits,
pos=pos,
idx_mapping=idx_mapping,
expanded_idx_mapping=expanded_idx_mapping,
expanded_local_pos=expanded_local_pos,
temperature=temp_tensor,
seed=seed,
)
def _assert_distribution_match(
sampled_tokens: torch.Tensor,
target_probs: torch.Tensor,
device: str,
label: str = "",
min_expected: float = 5.0,
):
"""
Assert sampled tokens match the target distribution via a
chi-squared goodness-of-fit test. This is done by computing
observed vs expected token counts (target_probs * num_samples),
then checking that the chi-squared statistic is below a conservative
threshold. The threshold is set at df + 10*sqrt(2*df), which
corresponds to ~10 sigma under the chi-squared distribution's
normal approximation, effectively disallowing false positives.
NOTE: Tokens with expected count < min_expected are merged into
a single "other" bin to minimize chi-squared noise.
"""
num_samples = sampled_tokens.shape[0]
vocab_size = target_probs.shape[0]
observed = torch.zeros(vocab_size, device=device, dtype=torch.float32)
observed.scatter_add_(0, sampled_tokens, torch.ones(num_samples, device=device))
expected = target_probs * num_samples
sufficient = expected >= min_expected
obs_main = observed[sufficient]
exp_main = expected[sufficient]
obs_other = observed[~sufficient].sum().unsqueeze(0)
exp_other = expected[~sufficient].sum().unsqueeze(0)
if exp_other.item() >= min_expected:
obs_all = torch.cat([obs_main, obs_other])
exp_all = torch.cat([exp_main, exp_other])
else:
obs_all = obs_main
exp_all = exp_main
chi2 = ((obs_all - exp_all) ** 2 / exp_all).sum().item()
df = obs_all.shape[0] - 1
if df < 1:
# All samples were merged into < 2 bins, which is too
# few to evaluate.
return
threshold = df + 10 * math.sqrt(2 * df)
prefix = f"[{label}] " if label else ""
assert chi2 < threshold, (
f"{prefix}Chi-squared test failed: chi2={chi2:.1f}, "
f"df={df}, threshold={threshold:.1f}. "
f"Output distribution does not match target distribution."
)
@pytest.mark.parametrize(
"num_speculative_steps,temperature",
[
(1, 0.6),
(3, 0.6),
(1, 1.0),
(3, 1.0),
],
)
def test_stochastic_rejection_sample(num_speculative_steps: int, temperature: float):
"""
Verify that rejection sampling produces the target distribution.
This is done by simulating many independent trials of speculative
decoding (from a fixed target and draft distribution). We then
run rejection sample on all of the trials (requests), and verify
that the sampled tokens at every position follow the target
distribution p(x).
"""
torch.manual_seed(42)
device = "cuda"
num_trials = 10 * VOCAB_SIZE
target_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)
draft_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)
if temperature > 0:
target_logits_1d /= temperature
draft_logits_1d /= temperature
inputs = _build_rejection_sample_inputs(
target_logits_1d,
draft_logits_1d,
num_speculative_steps,
temperature=temperature,
num_trials=num_trials,
)
sampled, num_sampled = probabilistic_rejection_sample(
**inputs, num_speculative_steps=num_speculative_steps
)
target_probs = torch.softmax(target_logits_1d, dim=0)
for pos in range(num_speculative_steps + 1):
accepted_mask = num_sampled >= pos + 1
_assert_distribution_match(
sampled[accepted_mask, pos], target_probs, device, label=f"position {pos}"
)
@pytest.mark.parametrize("num_speculative_steps", [1, 3])
def test_greedy_rejection_sample(num_speculative_steps: int):
"""
Verify that greedy (temperature=0) always outputs the target argmax
at every accepted position.
"""
torch.manual_seed(42)
device = "cuda"
num_trials = 10 * VOCAB_SIZE
target_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)
draft_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)
inputs = _build_rejection_sample_inputs(
target_logits_1d,
draft_logits_1d,
num_speculative_steps,
temperature=0.0,
num_trials=num_trials,
)
sampled, num_sampled = probabilistic_rejection_sample(
**inputs, num_speculative_steps=num_speculative_steps
)
target_argmax = target_logits_1d.argmax().item()
steps = torch.arange(num_speculative_steps + 1, device=device).unsqueeze(0)
accepted_mask = steps < num_sampled.unsqueeze(1)
assert (sampled[accepted_mask] == target_argmax).all(), (
"Greedy sampling produced tokens that are not the target argmax"
)
...@@ -65,36 +65,20 @@ def tl_rand64(seed, offset, includes_zero: tl.constexpr): ...@@ -65,36 +65,20 @@ def tl_rand64(seed, offset, includes_zero: tl.constexpr):
@triton.jit @triton.jit
def _gumbel_sample_kernel( def gumbel_block_argmax(
local_argmax_ptr, logits,
local_argmax_stride, block,
local_max_ptr, mask,
local_max_stride, token_idx,
processed_logits_ptr,
processed_logits_stride,
logits_ptr,
logits_stride,
expanded_idx_mapping_ptr, expanded_idx_mapping_ptr,
temp_ptr,
seeds_ptr, seeds_ptr,
pos_ptr, pos_ptr,
temp_ptr, processed_logits_ptr,
vocab_size, processed_logits_stride,
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr, APPLY_TEMPERATURE: tl.constexpr,
): ):
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx) req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = logits.to(tl.float32)
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32) temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
if temp != 0.0 and APPLY_TEMPERATURE: if temp != 0.0 and APPLY_TEMPERATURE:
# Apply temperature. # Apply temperature.
...@@ -102,8 +86,8 @@ def _gumbel_sample_kernel( ...@@ -102,8 +86,8 @@ def _gumbel_sample_kernel(
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too. # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp logits = logits / temp
# Store the temperature-applied logits.
if processed_logits_ptr is not None: if processed_logits_ptr is not None:
# Store the temperature-applied logits.
tl.store( tl.store(
processed_logits_ptr + req_state_idx * processed_logits_stride + block, processed_logits_ptr + req_state_idx * processed_logits_stride + block,
logits, logits,
...@@ -126,6 +110,51 @@ def _gumbel_sample_kernel( ...@@ -126,6 +110,51 @@ def _gumbel_sample_kernel(
logits = tl.where(mask, logits + gumbel_noise, float("-inf")) logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
value, idx = tl.max(logits, axis=0, return_indices=True) value, idx = tl.max(logits, axis=0, return_indices=True)
return value, idx
@triton.jit
def _gumbel_sample_kernel(
local_argmax_ptr,
local_argmax_stride,
local_max_ptr,
local_max_stride,
processed_logits_ptr,
processed_logits_stride,
logits_ptr,
logits_stride,
expanded_idx_mapping_ptr,
seeds_ptr,
pos_ptr,
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr,
):
token_idx = tl.program_id(0)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = logits.to(tl.float32)
value, idx = gumbel_block_argmax(
logits,
block,
mask,
token_idx,
expanded_idx_mapping_ptr,
temp_ptr,
seeds_ptr,
pos_ptr,
processed_logits_ptr,
processed_logits_stride,
APPLY_TEMPERATURE=APPLY_TEMPERATURE,
)
token_id = block_idx * BLOCK_SIZE + idx token_id = block_idx * BLOCK_SIZE + idx
tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id) tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value) tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value)
......
...@@ -7,11 +7,13 @@ from vllm.triton_utils import tl, triton ...@@ -7,11 +7,13 @@ from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample, tl_rand64
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.sampler import Sampler from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS
from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import (
probabilistic_rejection_sample,
)
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import ( from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
compute_synthetic_rejection_sampler_params, compute_synthetic_rejection_sampler_params,
synthetic_rejection_sample, synthetic_rejection_sample,
...@@ -75,357 +77,6 @@ def strict_rejection_sample( ...@@ -75,357 +77,6 @@ def strict_rejection_sample(
return sampled, num_sampled return sampled, num_sampled
@triton.jit
def _gather_draft_logits_and_target_argmax_kernel(
local_target_argmax_ptr,
local_target_argmax_stride,
local_target_max_ptr,
local_target_max_stride,
# [num_logits, V]
out_draft_logits_ptr,
out_draft_logits_stride,
# [num_logits, V]
target_logits_ptr,
target_logits_stride,
# [max_num_reqs, num_speculative_steps, V]
draft_logits_ptr,
draft_logits_stride_0,
draft_logits_stride_1,
# [num_logits]
expanded_idx_mapping_ptr,
# [num_logits]
expanded_local_pos_ptr,
# [max_num_reqs]
temp_ptr,
vocab_size,
num_speculative_steps,
BLOCK_SIZE: tl.constexpr,
):
logit_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
draft_step_idx = tl.load(expanded_local_pos_ptr + logit_idx)
block_idx = tl.program_id(1)
block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block_offsets < vocab_size
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
if temp == 0.0:
# Greedy sampling. Get the target logits argmax.
target_logits = tl.load(
target_logits_ptr + logit_idx * target_logits_stride + block_offsets,
mask=mask,
other=float("-inf"),
).to(tl.float32)
value, idx = tl.max(target_logits, axis=0, return_indices=True)
token_id = block_idx * BLOCK_SIZE + idx
tl.store(
local_target_argmax_ptr
+ logit_idx * local_target_argmax_stride
+ block_idx,
token_id,
)
tl.store(
local_target_max_ptr + logit_idx * local_target_max_stride + block_idx,
value,
)
elif draft_step_idx < num_speculative_steps:
draft_logits = tl.load(
draft_logits_ptr
+ req_state_idx * draft_logits_stride_0
+ draft_step_idx * draft_logits_stride_1
+ block_offsets,
mask=mask,
other=float("-inf"),
).to(tl.float32)
tl.store(
out_draft_logits_ptr + logit_idx * out_draft_logits_stride + block_offsets,
draft_logits,
mask=mask,
)
@triton.jit
def _probabilistic_rejection_kernel(
# [num_reqs, num_speculative_steps + 1]
sampled_ptr,
sampled_stride,
# [num_reqs]
rejected_steps_ptr,
# [num_reqs]
rejected_pos_ptr,
# [num_logits]
draft_sampled_ptr,
# [num_logits, V]
target_probs_ptr,
target_probs_stride,
# [num_logits, V]
draft_probs_ptr,
draft_probs_stride,
# [num_logits, num_blocks]
local_target_argmax_ptr,
local_target_argmax_stride,
# [num_logits, num_blocks]
local_target_max_ptr,
local_target_max_stride,
# [num_reqs + 1]
cu_num_logits_ptr,
# [num_logits]
pos_ptr,
# [num_reqs]
idx_mapping_ptr,
# [max_num_reqs]
temp_ptr,
# [max_num_reqs]
seeds_ptr,
NUM_BLOCKS: tl.constexpr,
PADDED_NUM_BLOCKS: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
num_tokens = tl.load(cu_num_logits_ptr + req_idx + 1) - start_idx
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
seed = tl.load(seeds_ptr + req_state_idx)
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
rejected_step = 0
accepted = True
for i in range(num_tokens - 1):
if accepted:
logit_idx = start_idx + i
draft_sampled = tl.load(draft_sampled_ptr + logit_idx + 1)
if temp == 0.0:
# Greedy sampling. Only accept the sampled draft token if
# it exactly matches the target argmax.
block_offsets = tl.arange(0, PADDED_NUM_BLOCKS)
block_mask = block_offsets < NUM_BLOCKS
local_max = tl.load(
local_target_max_ptr
+ logit_idx * local_target_max_stride
+ block_offsets,
mask=block_mask,
other=float("-inf"),
)
max_block = tl.argmax(local_max, axis=0)
target_argmax = tl.load(
local_target_argmax_ptr
+ logit_idx * local_target_argmax_stride
+ max_block
)
accepted &= target_argmax == draft_sampled
else:
target_prob = tl.load(
target_probs_ptr + logit_idx * target_probs_stride + draft_sampled
).to(tl.float64)
draft_prob = tl.load(
draft_probs_ptr + logit_idx * draft_probs_stride + draft_sampled
).to(tl.float64)
pos = tl.load(pos_ptr + logit_idx)
u = tl_rand64(seed, pos, includes_zero=False)
accepted &= target_prob > u * draft_prob
tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
rejected_step += accepted
tl.store(rejected_steps_ptr + req_idx, rejected_step)
pos_val = tl.load(pos_ptr + start_idx + rejected_step)
tl.store(rejected_pos_ptr + req_idx, pos_val)
@triton.jit
def _compute_residual_logits_kernel(
# [num_reqs, V]
residual_logits_ptr,
residual_logits_stride,
# [num_logits, V]
target_probs_ptr,
target_probs_stride,
# [num_logits, V]
draft_probs_ptr,
draft_probs_stride,
# [num_logits, V]
target_logits_ptr,
target_logits_stride,
# [num_reqs]
rejected_step_ptr,
# [num_reqs + 1]
cu_num_logits_ptr,
# [num_reqs]
idx_mapping_ptr,
# [max_num_reqs]
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
block_idx = tl.program_id(1)
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
rejected_logit_idx = start_idx + tl.load(rejected_step_ptr + req_idx)
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block_offsets < vocab_size
if temp == 0.0 or (rejected_logit_idx == end_idx - 1):
# Greedy sampling / bonus token. In either case, use the
# target logits directly to reduce numerical error.
residual_logits = tl.load(
target_logits_ptr
+ rejected_logit_idx * target_logits_stride
+ block_offsets,
mask=mask,
other=float("-inf"),
)
else:
target_probs = tl.load(
target_probs_ptr + rejected_logit_idx * target_probs_stride + block_offsets,
mask=mask,
other=0.0,
)
draft_probs = tl.load(
draft_probs_ptr + rejected_logit_idx * draft_probs_stride + block_offsets,
mask=mask,
other=0.0,
)
residual_probs = tl.maximum(target_probs - draft_probs, 0.0)
residual_logits = tl.log(residual_probs)
tl.store(
residual_logits_ptr + req_idx * residual_logits_stride + block_offsets,
residual_logits,
mask=mask,
)
def probabilistic_rejection_sample(
# [num_logits, V]
target_logits: torch.Tensor,
# [max_num_reqs, num_speculative_steps, V]
draft_logits: torch.Tensor,
# [num_logits]
draft_sampled: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
# [num_logits]
pos: torch.Tensor,
# [num_reqs]
idx_mapping: torch.Tensor,
# [num_logits]
expanded_idx_mapping: torch.Tensor,
# [num_logits]
expanded_local_pos: torch.Tensor,
# [max_num_reqs]
temperature: torch.Tensor,
# [max_num_reqs]
seed: torch.Tensor,
num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
num_logits, vocab_size = target_logits.shape
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
# Gather draft logits and target argmax for greedy sampling.
gathered_draft_logits = target_logits.new_empty(target_logits.shape)
local_target_argmax = target_logits.new_empty(
num_logits, num_blocks, dtype=torch.int64
)
local_target_max = target_logits.new_empty(
num_logits, num_blocks, dtype=torch.float32
)
_gather_draft_logits_and_target_argmax_kernel[(num_logits, num_blocks)](
local_target_argmax,
local_target_argmax.stride(0),
local_target_max,
local_target_max.stride(0),
gathered_draft_logits,
gathered_draft_logits.stride(0),
target_logits,
target_logits.stride(0),
draft_logits,
draft_logits.stride(0),
draft_logits.stride(1),
expanded_idx_mapping,
expanded_local_pos,
temperature,
vocab_size,
num_speculative_steps,
BLOCK_SIZE=BLOCK_SIZE,
)
# Compute target and draft probs.
target_probs = torch.softmax(target_logits, dim=-1)
draft_probs = torch.softmax(gathered_draft_logits, dim=-1)
# Rejection sample.
# [num_reqs, num_speculative_steps + 1]
sampled = draft_sampled.new_empty(
num_reqs, num_speculative_steps + 1, dtype=torch.int64
)
# [num_reqs]
rejected_steps = sampled.new_empty(num_reqs)
# [num_reqs]
rejected_pos = pos.new_empty(num_reqs)
_probabilistic_rejection_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
rejected_steps,
rejected_pos,
draft_sampled,
target_probs,
target_probs.stride(0),
draft_probs,
draft_probs.stride(0),
local_target_argmax,
local_target_argmax.stride(0),
local_target_max,
local_target_max.stride(0),
cu_num_logits,
pos,
idx_mapping,
temperature,
seed,
num_warps=1,
NUM_BLOCKS=num_blocks,
PADDED_NUM_BLOCKS=triton.next_power_of_2(num_blocks),
)
# Compute the logits and positions to resample the rejected/bonus
# tokens from.
# [num_reqs, vocab_size]
residual_logits = target_logits.new_empty(num_reqs, vocab_size)
_compute_residual_logits_kernel[(num_reqs, num_blocks)](
residual_logits,
residual_logits.stride(0),
target_probs,
target_probs.stride(0),
draft_probs,
draft_probs.stride(0),
target_logits,
target_logits.stride(0),
rejected_steps,
cu_num_logits,
idx_mapping,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
# Gumbel sample tokens from the residual distribution.
resampled = gumbel_sample(
residual_logits,
idx_mapping,
temperature,
seed,
rejected_pos,
apply_temperature=False,
)
sampled.scatter_(1, rejected_steps.unsqueeze(1), resampled.unsqueeze(1))
return sampled, rejected_steps + 1
@triton.jit @triton.jit
def _flatten_sampled_kernel( def _flatten_sampled_kernel(
# [num_logits] # [num_logits]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment