Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5daf6227
Unverified
Commit
5daf6227
authored
Apr 07, 2026
by
Giancarlo Delfin
Committed by
GitHub
Apr 07, 2026
Browse files
[Model Runner V2] Fuse probabilistic rejection sample kernels (#38496)
Signed-off-by:
Giancarlo Delfin
<
gdelfin@inferact.ai
>
parent
ad330442
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
886 additions
and
377 deletions
+886
-377
.buildkite/test_areas/model_runner_v2.yaml
.buildkite/test_areas/model_runner_v2.yaml
+2
-0
tests/v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
...spec_decode/test_probabilistic_rejection_sampler_utils.py
+215
-0
vllm/v1/worker/gpu/sample/gumbel.py
vllm/v1/worker/gpu/sample/gumbel.py
+54
-25
vllm/v1/worker/gpu/spec_decode/probabilistic_rejection_sampler_utils.py
.../gpu/spec_decode/probabilistic_rejection_sampler_utils.py
+612
-0
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
+3
-352
No files found.
.buildkite/test_areas/model_runner_v2.yaml
View file @
5daf6227
...
@@ -100,11 +100,13 @@ steps:
...
@@ -100,11 +100,13 @@ steps:
-
vllm/v1/worker/gpu/
-
vllm/v1/worker/gpu/
-
vllm/v1/worker/gpu_worker.py
-
vllm/v1/worker/gpu_worker.py
-
tests/v1/spec_decode/test_max_len.py
-
tests/v1/spec_decode/test_max_len.py
-
tests/v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
-
tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py
-
tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py
-
tests/v1/e2e/spec_decode/test_spec_decode.py
-
tests/v1/e2e/spec_decode/test_spec_decode.py
commands
:
commands
:
-
set -x
-
set -x
-
export VLLM_USE_V2_MODEL_RUNNER=1
-
export VLLM_USE_V2_MODEL_RUNNER=1
-
pytest -v -s v1/spec_decode/test_max_len.py -k "eagle or mtp"
-
pytest -v -s v1/spec_decode/test_max_len.py -k "eagle or mtp"
-
pytest -v -s v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
-
pytest -v -s v1/spec_decode/test_synthetic_rejection_sampler_utils.py
-
pytest -v -s v1/spec_decode/test_synthetic_rejection_sampler_utils.py
-
pytest -v -s v1/e2e/spec_decode/test_spec_decode.py -k "eagle or mtp"
-
pytest -v -s v1/e2e/spec_decode/test_spec_decode.py -k "eagle or mtp"
tests/v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
0 → 100644
View file @
5daf6227
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
pytest
import
torch
from
vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils
import
(
probabilistic_rejection_sample
,
)
VOCAB_SIZE
=
4096
# Skip if no CUDA - Triton kernel requires GPU
pytest
.
importorskip
(
"triton"
)
if
not
torch
.
cuda
.
is_available
():
pytest
.
skip
(
"CUDA required for rejection sampler tests"
,
allow_module_level
=
True
)
def
_build_rejection_sample_inputs
(
target_logits_1d
:
torch
.
Tensor
,
draft_logits_1d
:
torch
.
Tensor
,
num_speculative_steps
:
int
,
temperature
:
float
,
num_trials
:
int
,
)
->
dict
:
device
=
target_logits_1d
.
device
vocab_size
=
target_logits_1d
.
shape
[
0
]
K
=
num_speculative_steps
num_logits
=
num_trials
*
(
K
+
1
)
target_logits
=
target_logits_1d
.
unsqueeze
(
0
).
expand
(
num_logits
,
-
1
).
contiguous
()
draft_logits
=
(
draft_logits_1d
.
view
(
1
,
1
,
vocab_size
).
expand
(
num_trials
,
K
,
-
1
).
contiguous
()
)
draft_probs
=
torch
.
softmax
(
draft_logits_1d
,
dim
=
0
)
draft_tokens
=
torch
.
multinomial
(
draft_probs
.
expand
(
num_trials
,
-
1
),
K
,
replacement
=
True
)
draft_sampled_2d
=
torch
.
zeros
(
num_trials
,
K
+
1
,
dtype
=
torch
.
int64
,
device
=
device
)
draft_sampled_2d
[:,
1
:]
=
draft_tokens
draft_sampled
=
draft_sampled_2d
.
reshape
(
-
1
)
cu_num_logits
=
torch
.
arange
(
num_trials
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
*
(
K
+
1
)
pos
=
torch
.
arange
(
num_logits
,
dtype
=
torch
.
int32
,
device
=
device
)
idx_mapping
=
torch
.
arange
(
num_trials
,
dtype
=
torch
.
int32
,
device
=
device
)
expanded_idx_mapping
=
torch
.
arange
(
num_trials
,
dtype
=
torch
.
int32
,
device
=
device
).
repeat_interleave
(
K
+
1
)
expanded_local_pos
=
torch
.
arange
(
K
+
1
,
dtype
=
torch
.
int32
,
device
=
device
).
repeat
(
num_trials
)
temp_tensor
=
torch
.
full
(
(
num_trials
,),
temperature
,
dtype
=
torch
.
float32
,
device
=
device
)
seed
=
torch
.
arange
(
num_trials
,
dtype
=
torch
.
int64
,
device
=
device
)
return
dict
(
target_logits
=
target_logits
,
draft_logits
=
draft_logits
,
draft_sampled
=
draft_sampled
,
cu_num_logits
=
cu_num_logits
,
pos
=
pos
,
idx_mapping
=
idx_mapping
,
expanded_idx_mapping
=
expanded_idx_mapping
,
expanded_local_pos
=
expanded_local_pos
,
temperature
=
temp_tensor
,
seed
=
seed
,
)
def
_assert_distribution_match
(
sampled_tokens
:
torch
.
Tensor
,
target_probs
:
torch
.
Tensor
,
device
:
str
,
label
:
str
=
""
,
min_expected
:
float
=
5.0
,
):
"""
Assert sampled tokens match the target distribution via a
chi-squared goodness-of-fit test. This is done by computing
observed vs expected token counts (target_probs * num_samples),
then checking that the chi-squared statistic is below a conservative
threshold. The threshold is set at df + 10*sqrt(2*df), which
corresponds to ~10 sigma under the chi-squared distribution's
normal approximation, effectively disallowing false positives.
NOTE: Tokens with expected count < min_expected are merged into
a single "other" bin to minimize chi-squared noise.
"""
num_samples
=
sampled_tokens
.
shape
[
0
]
vocab_size
=
target_probs
.
shape
[
0
]
observed
=
torch
.
zeros
(
vocab_size
,
device
=
device
,
dtype
=
torch
.
float32
)
observed
.
scatter_add_
(
0
,
sampled_tokens
,
torch
.
ones
(
num_samples
,
device
=
device
))
expected
=
target_probs
*
num_samples
sufficient
=
expected
>=
min_expected
obs_main
=
observed
[
sufficient
]
exp_main
=
expected
[
sufficient
]
obs_other
=
observed
[
~
sufficient
].
sum
().
unsqueeze
(
0
)
exp_other
=
expected
[
~
sufficient
].
sum
().
unsqueeze
(
0
)
if
exp_other
.
item
()
>=
min_expected
:
obs_all
=
torch
.
cat
([
obs_main
,
obs_other
])
exp_all
=
torch
.
cat
([
exp_main
,
exp_other
])
else
:
obs_all
=
obs_main
exp_all
=
exp_main
chi2
=
((
obs_all
-
exp_all
)
**
2
/
exp_all
).
sum
().
item
()
df
=
obs_all
.
shape
[
0
]
-
1
if
df
<
1
:
# All samples were merged into < 2 bins, which is too
# few to evaluate.
return
threshold
=
df
+
10
*
math
.
sqrt
(
2
*
df
)
prefix
=
f
"[
{
label
}
] "
if
label
else
""
assert
chi2
<
threshold
,
(
f
"
{
prefix
}
Chi-squared test failed: chi2=
{
chi2
:.
1
f
}
, "
f
"df=
{
df
}
, threshold=
{
threshold
:.
1
f
}
. "
f
"Output distribution does not match target distribution."
)
@
pytest
.
mark
.
parametrize
(
"num_speculative_steps,temperature"
,
[
(
1
,
0.6
),
(
3
,
0.6
),
(
1
,
1.0
),
(
3
,
1.0
),
],
)
def
test_stochastic_rejection_sample
(
num_speculative_steps
:
int
,
temperature
:
float
):
"""
Verify that rejection sampling produces the target distribution.
This is done by simulating many independent trials of speculative
decoding (from a fixed target and draft distribution). We then
run rejection sample on all of the trials (requests), and verify
that the sampled tokens at every position follow the target
distribution p(x).
"""
torch
.
manual_seed
(
42
)
device
=
"cuda"
num_trials
=
10
*
VOCAB_SIZE
target_logits_1d
=
torch
.
randn
(
VOCAB_SIZE
,
device
=
device
,
dtype
=
torch
.
float32
)
draft_logits_1d
=
torch
.
randn
(
VOCAB_SIZE
,
device
=
device
,
dtype
=
torch
.
float32
)
if
temperature
>
0
:
target_logits_1d
/=
temperature
draft_logits_1d
/=
temperature
inputs
=
_build_rejection_sample_inputs
(
target_logits_1d
,
draft_logits_1d
,
num_speculative_steps
,
temperature
=
temperature
,
num_trials
=
num_trials
,
)
sampled
,
num_sampled
=
probabilistic_rejection_sample
(
**
inputs
,
num_speculative_steps
=
num_speculative_steps
)
target_probs
=
torch
.
softmax
(
target_logits_1d
,
dim
=
0
)
for
pos
in
range
(
num_speculative_steps
+
1
):
accepted_mask
=
num_sampled
>=
pos
+
1
_assert_distribution_match
(
sampled
[
accepted_mask
,
pos
],
target_probs
,
device
,
label
=
f
"position
{
pos
}
"
)
@
pytest
.
mark
.
parametrize
(
"num_speculative_steps"
,
[
1
,
3
])
def
test_greedy_rejection_sample
(
num_speculative_steps
:
int
):
"""
Verify that greedy (temperature=0) always outputs the target argmax
at every accepted position.
"""
torch
.
manual_seed
(
42
)
device
=
"cuda"
num_trials
=
10
*
VOCAB_SIZE
target_logits_1d
=
torch
.
randn
(
VOCAB_SIZE
,
device
=
device
,
dtype
=
torch
.
float32
)
draft_logits_1d
=
torch
.
randn
(
VOCAB_SIZE
,
device
=
device
,
dtype
=
torch
.
float32
)
inputs
=
_build_rejection_sample_inputs
(
target_logits_1d
,
draft_logits_1d
,
num_speculative_steps
,
temperature
=
0.0
,
num_trials
=
num_trials
,
)
sampled
,
num_sampled
=
probabilistic_rejection_sample
(
**
inputs
,
num_speculative_steps
=
num_speculative_steps
)
target_argmax
=
target_logits_1d
.
argmax
().
item
()
steps
=
torch
.
arange
(
num_speculative_steps
+
1
,
device
=
device
).
unsqueeze
(
0
)
accepted_mask
=
steps
<
num_sampled
.
unsqueeze
(
1
)
assert
(
sampled
[
accepted_mask
]
==
target_argmax
).
all
(),
(
"Greedy sampling produced tokens that are not the target argmax"
)
vllm/v1/worker/gpu/sample/gumbel.py
View file @
5daf6227
...
@@ -65,36 +65,20 @@ def tl_rand64(seed, offset, includes_zero: tl.constexpr):
...
@@ -65,36 +65,20 @@ def tl_rand64(seed, offset, includes_zero: tl.constexpr):
@
triton
.
jit
@
triton
.
jit
def
_gumbel_sample_kernel
(
def
gumbel_block_argmax
(
local_argmax_ptr
,
logits
,
local_argmax_stride
,
block
,
local_max_ptr
,
mask
,
local_max_stride
,
token_idx
,
processed_logits_ptr
,
processed_logits_stride
,
logits_ptr
,
logits_stride
,
expanded_idx_mapping_ptr
,
expanded_idx_mapping_ptr
,
temp_ptr
,
seeds_ptr
,
seeds_ptr
,
pos_ptr
,
pos_ptr
,
temp_ptr
,
processed_logits_ptr
,
vocab_size
,
processed_logits_stride
,
BLOCK_SIZE
:
tl
.
constexpr
,
APPLY_TEMPERATURE
:
tl
.
constexpr
,
APPLY_TEMPERATURE
:
tl
.
constexpr
,
):
):
token_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
expanded_idx_mapping_ptr
+
token_idx
)
req_state_idx
=
tl
.
load
(
expanded_idx_mapping_ptr
+
token_idx
)
block_idx
=
tl
.
program_id
(
1
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
logits
=
tl
.
load
(
logits_ptr
+
token_idx
*
logits_stride
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
),
)
logits
=
logits
.
to
(
tl
.
float32
)
temp
=
tl
.
load
(
temp_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
temp
=
tl
.
load
(
temp_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
if
temp
!=
0.0
and
APPLY_TEMPERATURE
:
if
temp
!=
0.0
and
APPLY_TEMPERATURE
:
# Apply temperature.
# Apply temperature.
...
@@ -102,8 +86,8 @@ def _gumbel_sample_kernel(
...
@@ -102,8 +86,8 @@ def _gumbel_sample_kernel(
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits
=
logits
/
temp
logits
=
logits
/
temp
# Store the temperature-applied logits.
if
processed_logits_ptr
is
not
None
:
if
processed_logits_ptr
is
not
None
:
# Store the temperature-applied logits.
tl
.
store
(
tl
.
store
(
processed_logits_ptr
+
req_state_idx
*
processed_logits_stride
+
block
,
processed_logits_ptr
+
req_state_idx
*
processed_logits_stride
+
block
,
logits
,
logits
,
...
@@ -126,6 +110,51 @@ def _gumbel_sample_kernel(
...
@@ -126,6 +110,51 @@ def _gumbel_sample_kernel(
logits
=
tl
.
where
(
mask
,
logits
+
gumbel_noise
,
float
(
"-inf"
))
logits
=
tl
.
where
(
mask
,
logits
+
gumbel_noise
,
float
(
"-inf"
))
value
,
idx
=
tl
.
max
(
logits
,
axis
=
0
,
return_indices
=
True
)
value
,
idx
=
tl
.
max
(
logits
,
axis
=
0
,
return_indices
=
True
)
return
value
,
idx
@
triton
.
jit
def
_gumbel_sample_kernel
(
local_argmax_ptr
,
local_argmax_stride
,
local_max_ptr
,
local_max_stride
,
processed_logits_ptr
,
processed_logits_stride
,
logits_ptr
,
logits_stride
,
expanded_idx_mapping_ptr
,
seeds_ptr
,
pos_ptr
,
temp_ptr
,
vocab_size
,
BLOCK_SIZE
:
tl
.
constexpr
,
APPLY_TEMPERATURE
:
tl
.
constexpr
,
):
token_idx
=
tl
.
program_id
(
0
)
block_idx
=
tl
.
program_id
(
1
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
logits
=
tl
.
load
(
logits_ptr
+
token_idx
*
logits_stride
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
),
)
logits
=
logits
.
to
(
tl
.
float32
)
value
,
idx
=
gumbel_block_argmax
(
logits
,
block
,
mask
,
token_idx
,
expanded_idx_mapping_ptr
,
temp_ptr
,
seeds_ptr
,
pos_ptr
,
processed_logits_ptr
,
processed_logits_stride
,
APPLY_TEMPERATURE
=
APPLY_TEMPERATURE
,
)
token_id
=
block_idx
*
BLOCK_SIZE
+
idx
token_id
=
block_idx
*
BLOCK_SIZE
+
idx
tl
.
store
(
local_argmax_ptr
+
token_idx
*
local_argmax_stride
+
block_idx
,
token_id
)
tl
.
store
(
local_argmax_ptr
+
token_idx
*
local_argmax_stride
+
block_idx
,
token_id
)
tl
.
store
(
local_max_ptr
+
token_idx
*
local_max_stride
+
block_idx
,
value
)
tl
.
store
(
local_max_ptr
+
token_idx
*
local_max_stride
+
block_idx
,
value
)
...
...
vllm/v1/worker/gpu/spec_decode/probabilistic_rejection_sampler_utils.py
0 → 100644
View file @
5daf6227
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_block_argmax
,
tl_rand64
@
triton
.
jit
def
_compute_block_max_and_sumexp
(
logits
):
block_max
=
tl
.
max
(
logits
,
axis
=
0
)
block_sumexp
=
tl
.
where
(
block_max
>
float
(
"-inf"
),
tl
.
sum
(
tl
.
exp
(
logits
-
block_max
)),
0.0
,
)
return
block_max
,
block_sumexp
@
triton
.
jit
def
_compute_global_lse
(
local_max_ptr
,
local_max_stride
,
local_sumexp_ptr
,
local_sumexp_stride
,
logit_idx
,
vocab_num_blocks
,
PADDED_VOCAB_NUM_BLOCKS
:
tl
.
constexpr
,
):
blocks
=
tl
.
arange
(
0
,
PADDED_VOCAB_NUM_BLOCKS
)
blocks_mask
=
blocks
<
vocab_num_blocks
maxes
=
tl
.
load
(
local_max_ptr
+
logit_idx
*
local_max_stride
+
blocks
,
mask
=
blocks_mask
,
other
=
float
(
"-inf"
),
)
sumexps
=
tl
.
load
(
local_sumexp_ptr
+
logit_idx
*
local_sumexp_stride
+
blocks
,
mask
=
blocks_mask
,
other
=
0.0
,
)
global_max
=
tl
.
max
(
maxes
,
axis
=
0
)
global_lse
=
global_max
+
tl
.
log
(
tl
.
sum
(
sumexps
*
tl
.
exp
(
maxes
-
global_max
)))
return
global_lse
@
triton
.
jit
def
_compute_block_max_and_sumexp_kernel
(
# [num_logits, num_blocks]
target_local_argmax_ptr
,
target_local_argmax_stride
,
# [num_logits, num_blocks]
target_local_max_ptr
,
target_local_max_stride
,
# [num_logits, num_blocks]
target_local_sumexp_ptr
,
target_local_sumexp_stride
,
# [num_logits, num_blocks]
draft_local_max_ptr
,
draft_local_max_stride
,
# [num_logits, num_blocks]
draft_local_sumexp_ptr
,
draft_local_sumexp_stride
,
# [num_logits, V]
target_logits_ptr
,
target_logits_stride
,
# [max_num_reqs, num_speculative_steps, V]
draft_logits_ptr
,
draft_logits_stride_0
,
draft_logits_stride_1
,
# [num_logits]
expanded_idx_mapping_ptr
,
# [num_logits]
expanded_local_pos_ptr
,
# [max_num_reqs]
temp_ptr
,
vocab_size
,
num_speculative_steps
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
logit_idx
=
tl
.
program_id
(
0
)
draft_step_idx
=
tl
.
load
(
expanded_local_pos_ptr
+
logit_idx
)
if
draft_step_idx
>=
num_speculative_steps
:
# Bonus token. Max/argmax and summed exponentials are not needed.
return
req_state_idx
=
tl
.
load
(
expanded_idx_mapping_ptr
+
logit_idx
)
temp
=
tl
.
load
(
temp_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
block_idx
=
tl
.
program_id
(
1
)
block_offsets
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block_offsets
<
vocab_size
if
temp
==
0.0
:
# Greedy sampling. Only the target max/argmax are needed.
target_logits
=
tl
.
load
(
target_logits_ptr
+
logit_idx
*
target_logits_stride
+
block_offsets
,
mask
=
mask
,
other
=
float
(
"-inf"
),
).
to
(
tl
.
float32
)
value
,
idx
=
tl
.
max
(
target_logits
,
axis
=
0
,
return_indices
=
True
)
token_id
=
block_idx
*
BLOCK_SIZE
+
idx
tl
.
store
(
target_local_argmax_ptr
+
logit_idx
*
target_local_argmax_stride
+
block_idx
,
token_id
,
)
tl
.
store
(
target_local_max_ptr
+
logit_idx
*
target_local_max_stride
+
block_idx
,
value
,
)
else
:
# Get local draft max and summed exponentials.
draft_logits
=
tl
.
load
(
draft_logits_ptr
+
req_state_idx
*
draft_logits_stride_0
+
draft_step_idx
*
draft_logits_stride_1
+
block_offsets
,
mask
=
mask
,
other
=
float
(
"-inf"
),
).
to
(
tl
.
float32
)
draft_max
,
draft_sumexp
=
_compute_block_max_and_sumexp
(
draft_logits
)
tl
.
store
(
draft_local_max_ptr
+
logit_idx
*
draft_local_max_stride
+
block_idx
,
draft_max
,
)
tl
.
store
(
draft_local_sumexp_ptr
+
logit_idx
*
draft_local_sumexp_stride
+
block_idx
,
draft_sumexp
,
)
# Get local target max and summed exponentials.
target_logits
=
tl
.
load
(
target_logits_ptr
+
logit_idx
*
target_logits_stride
+
block_offsets
,
mask
=
mask
,
other
=
float
(
"-inf"
),
).
to
(
tl
.
float32
)
target_max
,
target_sumexp
=
_compute_block_max_and_sumexp
(
target_logits
)
tl
.
store
(
target_local_max_ptr
+
logit_idx
*
target_local_max_stride
+
block_idx
,
target_max
,
)
tl
.
store
(
target_local_sumexp_ptr
+
logit_idx
*
target_local_sumexp_stride
+
block_idx
,
target_sumexp
,
)
@
triton
.
jit
def
_probabilistic_rejection_kernel
(
# [num_reqs, num_speculative_steps + 1]
sampled_ptr
,
sampled_stride
,
# [num_reqs]
rejected_steps_ptr
,
# [num_reqs]
target_rejected_logsumexp_ptr
,
# [num_reqs]
draft_rejected_logsumexp_ptr
,
# [num_logits, V]
target_logits_ptr
,
target_logits_stride
,
# [num_logits, num_blocks]
target_local_argmax_ptr
,
target_local_argmax_stride
,
# [num_logits, num_blocks]
target_local_max_ptr
,
target_local_max_stride
,
# [num_logits, num_blocks]
target_local_sumexp_ptr
,
target_local_sumexp_stride
,
# [num_logits]
draft_sampled_ptr
,
# [max_num_reqs, num_speculative_steps, V]
draft_logits_ptr
,
draft_logits_stride_0
,
draft_logits_stride_1
,
# [num_logits, num_blocks]
draft_local_max_ptr
,
draft_local_max_stride
,
# [num_logits, num_blocks]
draft_local_sumexp_ptr
,
draft_local_sumexp_stride
,
# [num_reqs + 1]
cu_num_logits_ptr
,
# [num_reqs]
idx_mapping_ptr
,
# [max_num_reqs]
temp_ptr
,
# [max_num_reqs]
seed_ptr
,
# [num_logits]
pos_ptr
,
vocab_num_blocks
,
PADDED_VOCAB_NUM_BLOCKS
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_idx
)
start_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
)
end_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
+
1
)
num_tokens
=
end_idx
-
start_idx
seed
=
tl
.
load
(
seed_ptr
+
req_state_idx
)
temp
=
tl
.
load
(
temp_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
rejected_step
=
0
target_lse
=
0.0
draft_lse
=
0.0
accepted
=
True
for
i
in
range
(
num_tokens
-
1
):
if
accepted
:
logit_idx
=
start_idx
+
i
draft_sampled
=
tl
.
load
(
draft_sampled_ptr
+
logit_idx
+
1
)
if
temp
==
0.0
:
# Greedy sampling. Accept IFF draft matches target argmax.
# NOTE: Target argmax is stored directly so that resampling
# can be skipped upon rejection.
target_blocks
=
tl
.
arange
(
0
,
PADDED_VOCAB_NUM_BLOCKS
)
target_blocks_mask
=
target_blocks
<
vocab_num_blocks
target_local_max
=
tl
.
load
(
target_local_max_ptr
+
logit_idx
*
target_local_max_stride
+
target_blocks
,
mask
=
target_blocks_mask
,
other
=
float
(
"-inf"
),
)
max_target_block_idx
=
tl
.
argmax
(
target_local_max
,
axis
=
0
)
target_argmax
=
tl
.
load
(
target_local_argmax_ptr
+
logit_idx
*
target_local_argmax_stride
+
max_target_block_idx
)
accepted
&=
target_argmax
==
draft_sampled
tl
.
store
(
sampled_ptr
+
req_idx
*
sampled_stride
+
i
,
target_argmax
)
else
:
target_logit
=
tl
.
load
(
target_logits_ptr
+
logit_idx
*
target_logits_stride
+
draft_sampled
).
to
(
tl
.
float32
)
draft_logit
=
tl
.
load
(
draft_logits_ptr
+
req_state_idx
*
draft_logits_stride_0
+
i
*
draft_logits_stride_1
+
draft_sampled
).
to
(
tl
.
float32
)
target_lse
=
_compute_global_lse
(
target_local_max_ptr
,
target_local_max_stride
,
target_local_sumexp_ptr
,
target_local_sumexp_stride
,
logit_idx
,
vocab_num_blocks
,
PADDED_VOCAB_NUM_BLOCKS
,
)
draft_lse
=
_compute_global_lse
(
draft_local_max_ptr
,
draft_local_max_stride
,
draft_local_sumexp_ptr
,
draft_local_sumexp_stride
,
logit_idx
,
vocab_num_blocks
,
PADDED_VOCAB_NUM_BLOCKS
,
)
target_log_prob
=
target_logit
-
target_lse
draft_log_prob
=
draft_logit
-
draft_lse
pos
=
tl
.
load
(
pos_ptr
+
logit_idx
)
u
=
tl_rand64
(
seed
,
pos
,
includes_zero
=
False
)
# Probability ratio test: p(x) > u * q(x)
# Equivalent log form: log_p(x) > log(u) + log_q(x)
accepted
&=
target_log_prob
>
tl
.
log
(
u
)
+
draft_log_prob
tl
.
store
(
sampled_ptr
+
req_idx
*
sampled_stride
+
i
,
draft_sampled
)
rejected_step
+=
accepted
tl
.
store
(
rejected_steps_ptr
+
req_idx
,
rejected_step
)
tl
.
store
(
target_rejected_logsumexp_ptr
+
req_idx
,
target_lse
)
tl
.
store
(
draft_rejected_logsumexp_ptr
+
req_idx
,
draft_lse
)
@
triton
.
jit
def
_resample_kernel
(
# [num_reqs, num_blocks]
resampled_local_argmax_ptr
,
resampled_local_argmax_stride
,
# [num_reqs, num_blocks]
resampled_local_max_ptr
,
resampled_local_max_stride
,
# [num_logits, V]
target_logits_ptr
,
target_logits_stride
,
# [num_reqs]
target_rejected_logsumexp_ptr
,
# [max_num_reqs, num_speculative_steps, V]
draft_logits_ptr
,
draft_logits_stride_0
,
draft_logits_stride_1
,
# [num_reqs]
draft_rejected_logsumexp_ptr
,
# [num_reqs]
rejected_step_ptr
,
# [num_reqs + 1]
cu_num_logits_ptr
,
# [num_logits]
expanded_idx_mapping_ptr
,
# [max_num_reqs]
temp_ptr
,
# [max_num_reqs]
seed_ptr
,
# [num_logits]
pos_ptr
,
vocab_size
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
resample_idx
=
tl
.
load
(
rejected_step_ptr
+
req_idx
)
start_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
)
end_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
+
1
)
resample_token_idx
=
start_idx
+
resample_idx
req_state_idx
=
tl
.
load
(
expanded_idx_mapping_ptr
+
resample_token_idx
)
temp
=
tl
.
load
(
temp_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
is_bonus
=
resample_token_idx
==
end_idx
-
1
if
temp
==
0.0
and
not
is_bonus
:
# Greedy + non-bonus token. No resampling needed because
# the target argmax is already in the sampled tensor.
return
block_idx
=
tl
.
program_id
(
1
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
# Compute the residual logits to resample the rejected token
# from. In the case of no rejections (bonus token), we directly
# use the target logits.
if
is_bonus
:
residual_logits
=
tl
.
load
(
target_logits_ptr
+
resample_token_idx
*
target_logits_stride
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
),
).
to
(
tl
.
float32
)
else
:
target_logits
=
tl
.
load
(
target_logits_ptr
+
resample_token_idx
*
target_logits_stride
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
),
).
to
(
tl
.
float32
)
draft_logits
=
tl
.
load
(
draft_logits_ptr
+
req_state_idx
*
draft_logits_stride_0
+
resample_idx
*
draft_logits_stride_1
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
),
).
to
(
tl
.
float32
)
target_lse
=
tl
.
load
(
target_rejected_logsumexp_ptr
+
req_idx
)
draft_lse
=
tl
.
load
(
draft_rejected_logsumexp_ptr
+
req_idx
)
target_log_probs
=
target_logits
-
target_lse
draft_log_probs
=
draft_logits
-
draft_lse
# Compute the residual: max(p(x) - q(x), 0)
# Equivalent log form: log(max(exp(log_p(x)) - exp(log_q(x)), 0))
# The more numerically stable form is:
# log(max(exp(a) - exp(b), 0)) = a + log(max(1 - exp(b - a), 0))
ratio
=
tl
.
exp
(
draft_log_probs
-
target_log_probs
)
residual_logits
=
tl
.
where
(
ratio
<
1.0
,
target_log_probs
+
tl
.
log
(
1
-
ratio
),
float
(
"-inf"
),
).
to
(
tl
.
float32
)
# Resample the rejected/bonus token.
value
,
idx
=
gumbel_block_argmax
(
residual_logits
,
block
,
mask
,
resample_token_idx
,
expanded_idx_mapping_ptr
,
temp_ptr
,
seed_ptr
,
pos_ptr
,
None
,
0
,
APPLY_TEMPERATURE
=
False
,
)
token_id
=
block_idx
*
BLOCK_SIZE
+
idx
tl
.
store
(
resampled_local_argmax_ptr
+
req_idx
*
resampled_local_argmax_stride
+
block_idx
,
token_id
,
)
tl
.
store
(
resampled_local_max_ptr
+
req_idx
*
resampled_local_max_stride
+
block_idx
,
value
,
)
@
triton
.
jit
def
_insert_resampled_kernel
(
# [num_reqs, num_speculative_steps + 1]
sampled_ptr
,
sampled_stride
,
# [num_reqs]
num_sampled_ptr
,
# [num_reqs, num_blocks]
resampled_local_argmax_ptr
,
resampled_local_argmax_stride
,
# [num_reqs, num_blocks]
resampled_local_max_ptr
,
resampled_local_max_stride
,
resample_num_blocks
,
# [num_reqs + 1]
cu_num_logits_ptr
,
# [num_reqs]
expanded_idx_mapping_ptr
,
# [max_num_reqs]
temp_ptr
,
PADDED_RESAMPLE_NUM_BLOCKS
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
num_sampled
=
tl
.
load
(
num_sampled_ptr
+
req_idx
)
start_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
)
end_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
+
1
)
resample_token_idx
=
start_idx
+
num_sampled
req_state_idx
=
tl
.
load
(
expanded_idx_mapping_ptr
+
resample_token_idx
)
# Increment the number of sampled tokens.
tl
.
store
(
num_sampled_ptr
+
req_idx
,
num_sampled
+
1
)
temp
=
tl
.
load
(
temp_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
is_bonus
=
resample_token_idx
==
end_idx
-
1
if
temp
==
0.0
and
not
is_bonus
:
# Greedy + non-bonus token. The target argmax is already
# in the sampled tensor.
return
# Insert the resampled token.
block
=
tl
.
arange
(
0
,
PADDED_RESAMPLE_NUM_BLOCKS
)
mask
=
block
<
resample_num_blocks
resampled_local_max
=
tl
.
load
(
resampled_local_max_ptr
+
req_idx
*
resampled_local_max_stride
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
),
)
resampled_max_block_idx
=
tl
.
argmax
(
resampled_local_max
,
axis
=
0
)
resampled
=
tl
.
load
(
resampled_local_argmax_ptr
+
req_idx
*
resampled_local_argmax_stride
+
resampled_max_block_idx
,
)
tl
.
store
(
sampled_ptr
+
req_idx
*
sampled_stride
+
num_sampled
,
resampled
,
)
def
probabilistic_rejection_sample
(
# [num_logits, V]
target_logits
:
torch
.
Tensor
,
# [max_num_reqs, num_speculative_steps, V]
draft_logits
:
torch
.
Tensor
,
# [num_logits]
draft_sampled
:
torch
.
Tensor
,
# [num_reqs + 1]
cu_num_logits
:
torch
.
Tensor
,
# [num_logits]
pos
:
torch
.
Tensor
,
# [num_reqs]
idx_mapping
:
torch
.
Tensor
,
# [num_logits]
expanded_idx_mapping
:
torch
.
Tensor
,
# [num_logits]
expanded_local_pos
:
torch
.
Tensor
,
# [max_num_reqs]
temperature
:
torch
.
Tensor
,
# [max_num_reqs]
seed
:
torch
.
Tensor
,
num_speculative_steps
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num_reqs
=
cu_num_logits
.
shape
[
0
]
-
1
num_logits
,
vocab_size
=
target_logits
.
shape
# Gather draft logits, compute target argmax for greedy, and
# compute per-block LSE and max for non-greedy requests.
VOCAB_BLOCK_SIZE
=
8192
vocab_num_blocks
=
triton
.
cdiv
(
vocab_size
,
VOCAB_BLOCK_SIZE
)
padded_vocab_num_blocks
=
triton
.
next_power_of_2
(
vocab_num_blocks
)
target_local_argmax
=
target_logits
.
new_empty
(
num_logits
,
vocab_num_blocks
,
dtype
=
torch
.
int64
)
target_local_max
=
target_logits
.
new_empty
(
num_logits
,
vocab_num_blocks
,
dtype
=
torch
.
float32
)
target_local_sumexp
=
target_logits
.
new_empty
(
num_logits
,
vocab_num_blocks
,
dtype
=
torch
.
float32
)
draft_local_max
=
target_logits
.
new_empty
(
num_logits
,
vocab_num_blocks
,
dtype
=
torch
.
float32
)
draft_local_sumexp
=
target_logits
.
new_empty
(
num_logits
,
vocab_num_blocks
,
dtype
=
torch
.
float32
)
_compute_block_max_and_sumexp_kernel
[(
num_logits
,
vocab_num_blocks
)](
target_local_argmax
,
target_local_argmax
.
stride
(
0
),
target_local_max
,
target_local_max
.
stride
(
0
),
target_local_sumexp
,
target_local_sumexp
.
stride
(
0
),
draft_local_max
,
draft_local_max
.
stride
(
0
),
draft_local_sumexp
,
draft_local_sumexp
.
stride
(
0
),
target_logits
,
target_logits
.
stride
(
0
),
draft_logits
,
draft_logits
.
stride
(
0
),
draft_logits
.
stride
(
1
),
expanded_idx_mapping
,
expanded_local_pos
,
temperature
,
vocab_size
,
num_speculative_steps
,
BLOCK_SIZE
=
VOCAB_BLOCK_SIZE
,
)
# Sample up until the first rejected/bonus token, and store
# the step.
sampled
=
draft_sampled
.
new_empty
(
num_reqs
,
num_speculative_steps
+
1
,
dtype
=
torch
.
int64
)
num_sampled
=
sampled
.
new_empty
(
num_reqs
)
target_rejected_logsumexp
=
target_logits
.
new_empty
(
num_reqs
,
dtype
=
torch
.
float32
)
draft_rejected_logsumexp
=
target_logits
.
new_empty
(
num_reqs
,
dtype
=
torch
.
float32
)
_probabilistic_rejection_kernel
[(
num_reqs
,)](
sampled
,
sampled
.
stride
(
0
),
num_sampled
,
target_rejected_logsumexp
,
draft_rejected_logsumexp
,
target_logits
,
target_logits
.
stride
(
0
),
target_local_argmax
,
target_local_argmax
.
stride
(
0
),
target_local_max
,
target_local_max
.
stride
(
0
),
target_local_sumexp
,
target_local_sumexp
.
stride
(
0
),
draft_sampled
,
draft_logits
,
draft_logits
.
stride
(
0
),
draft_logits
.
stride
(
1
),
draft_local_max
,
draft_local_max
.
stride
(
0
),
draft_local_sumexp
,
draft_local_sumexp
.
stride
(
0
),
cu_num_logits
,
idx_mapping
,
temperature
,
seed
,
pos
,
vocab_num_blocks
,
PADDED_VOCAB_NUM_BLOCKS
=
padded_vocab_num_blocks
,
num_warps
=
1
,
)
# Resample the rejected/bonus tokens.
RESAMPLE_BLOCK_SIZE
=
1024
resample_num_blocks
=
triton
.
cdiv
(
vocab_size
,
RESAMPLE_BLOCK_SIZE
)
padded_resample_num_blocks
=
triton
.
next_power_of_2
(
resample_num_blocks
)
resampled_local_argmax
=
target_logits
.
new_empty
(
num_reqs
,
resample_num_blocks
,
dtype
=
torch
.
int64
)
resampled_local_max
=
target_logits
.
new_empty
(
num_reqs
,
resample_num_blocks
,
dtype
=
torch
.
float64
)
_resample_kernel
[(
num_reqs
,
resample_num_blocks
)](
resampled_local_argmax
,
resampled_local_argmax
.
stride
(
0
),
resampled_local_max
,
resampled_local_max
.
stride
(
0
),
target_logits
,
target_logits
.
stride
(
0
),
target_rejected_logsumexp
,
draft_logits
,
draft_logits
.
stride
(
0
),
draft_logits
.
stride
(
1
),
draft_rejected_logsumexp
,
num_sampled
,
cu_num_logits
,
expanded_idx_mapping
,
temperature
,
seed
,
pos
,
vocab_size
,
BLOCK_SIZE
=
RESAMPLE_BLOCK_SIZE
,
)
# Insert the resampled tokens into the output sampled.
_insert_resampled_kernel
[(
num_reqs
,)](
sampled
,
sampled
.
stride
(
0
),
num_sampled
,
resampled_local_argmax
,
resampled_local_argmax
.
stride
(
0
),
resampled_local_max
,
resampled_local_max
.
stride
(
0
),
resample_num_blocks
,
cu_num_logits
,
expanded_idx_mapping
,
temperature
,
PADDED_RESAMPLE_NUM_BLOCKS
=
padded_resample_num_blocks
,
)
return
sampled
,
num_sampled
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
View file @
5daf6227
...
@@ -7,11 +7,13 @@ from vllm.triton_utils import tl, triton
...
@@ -7,11 +7,13 @@ from vllm.triton_utils import tl, triton
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.metrics.logits
import
get_num_nans
from
vllm.v1.worker.gpu.metrics.logits
import
get_num_nans
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
,
tl_rand64
from
vllm.v1.worker.gpu.sample.logprob
import
compute_topk_logprobs
from
vllm.v1.worker.gpu.sample.logprob
import
compute_topk_logprobs
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.sample.states
import
NO_LOGPROBS
from
vllm.v1.worker.gpu.sample.states
import
NO_LOGPROBS
from
vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils
import
(
probabilistic_rejection_sample
,
)
from
vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils
import
(
from
vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils
import
(
compute_synthetic_rejection_sampler_params
,
compute_synthetic_rejection_sampler_params
,
synthetic_rejection_sample
,
synthetic_rejection_sample
,
...
@@ -75,357 +77,6 @@ def strict_rejection_sample(
...
@@ -75,357 +77,6 @@ def strict_rejection_sample(
return
sampled
,
num_sampled
return
sampled
,
num_sampled
@
triton
.
jit
def
_gather_draft_logits_and_target_argmax_kernel
(
local_target_argmax_ptr
,
local_target_argmax_stride
,
local_target_max_ptr
,
local_target_max_stride
,
# [num_logits, V]
out_draft_logits_ptr
,
out_draft_logits_stride
,
# [num_logits, V]
target_logits_ptr
,
target_logits_stride
,
# [max_num_reqs, num_speculative_steps, V]
draft_logits_ptr
,
draft_logits_stride_0
,
draft_logits_stride_1
,
# [num_logits]
expanded_idx_mapping_ptr
,
# [num_logits]
expanded_local_pos_ptr
,
# [max_num_reqs]
temp_ptr
,
vocab_size
,
num_speculative_steps
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
logit_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
expanded_idx_mapping_ptr
+
logit_idx
)
draft_step_idx
=
tl
.
load
(
expanded_local_pos_ptr
+
logit_idx
)
block_idx
=
tl
.
program_id
(
1
)
block_offsets
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block_offsets
<
vocab_size
temp
=
tl
.
load
(
temp_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
if
temp
==
0.0
:
# Greedy sampling. Get the target logits argmax.
target_logits
=
tl
.
load
(
target_logits_ptr
+
logit_idx
*
target_logits_stride
+
block_offsets
,
mask
=
mask
,
other
=
float
(
"-inf"
),
).
to
(
tl
.
float32
)
value
,
idx
=
tl
.
max
(
target_logits
,
axis
=
0
,
return_indices
=
True
)
token_id
=
block_idx
*
BLOCK_SIZE
+
idx
tl
.
store
(
local_target_argmax_ptr
+
logit_idx
*
local_target_argmax_stride
+
block_idx
,
token_id
,
)
tl
.
store
(
local_target_max_ptr
+
logit_idx
*
local_target_max_stride
+
block_idx
,
value
,
)
elif
draft_step_idx
<
num_speculative_steps
:
draft_logits
=
tl
.
load
(
draft_logits_ptr
+
req_state_idx
*
draft_logits_stride_0
+
draft_step_idx
*
draft_logits_stride_1
+
block_offsets
,
mask
=
mask
,
other
=
float
(
"-inf"
),
).
to
(
tl
.
float32
)
tl
.
store
(
out_draft_logits_ptr
+
logit_idx
*
out_draft_logits_stride
+
block_offsets
,
draft_logits
,
mask
=
mask
,
)
@
triton
.
jit
def
_probabilistic_rejection_kernel
(
# [num_reqs, num_speculative_steps + 1]
sampled_ptr
,
sampled_stride
,
# [num_reqs]
rejected_steps_ptr
,
# [num_reqs]
rejected_pos_ptr
,
# [num_logits]
draft_sampled_ptr
,
# [num_logits, V]
target_probs_ptr
,
target_probs_stride
,
# [num_logits, V]
draft_probs_ptr
,
draft_probs_stride
,
# [num_logits, num_blocks]
local_target_argmax_ptr
,
local_target_argmax_stride
,
# [num_logits, num_blocks]
local_target_max_ptr
,
local_target_max_stride
,
# [num_reqs + 1]
cu_num_logits_ptr
,
# [num_logits]
pos_ptr
,
# [num_reqs]
idx_mapping_ptr
,
# [max_num_reqs]
temp_ptr
,
# [max_num_reqs]
seeds_ptr
,
NUM_BLOCKS
:
tl
.
constexpr
,
PADDED_NUM_BLOCKS
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
start_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
)
num_tokens
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
+
1
)
-
start_idx
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_idx
)
seed
=
tl
.
load
(
seeds_ptr
+
req_state_idx
)
temp
=
tl
.
load
(
temp_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
rejected_step
=
0
accepted
=
True
for
i
in
range
(
num_tokens
-
1
):
if
accepted
:
logit_idx
=
start_idx
+
i
draft_sampled
=
tl
.
load
(
draft_sampled_ptr
+
logit_idx
+
1
)
if
temp
==
0.0
:
# Greedy sampling. Only accept the sampled draft token if
# it exactly matches the target argmax.
block_offsets
=
tl
.
arange
(
0
,
PADDED_NUM_BLOCKS
)
block_mask
=
block_offsets
<
NUM_BLOCKS
local_max
=
tl
.
load
(
local_target_max_ptr
+
logit_idx
*
local_target_max_stride
+
block_offsets
,
mask
=
block_mask
,
other
=
float
(
"-inf"
),
)
max_block
=
tl
.
argmax
(
local_max
,
axis
=
0
)
target_argmax
=
tl
.
load
(
local_target_argmax_ptr
+
logit_idx
*
local_target_argmax_stride
+
max_block
)
accepted
&=
target_argmax
==
draft_sampled
else
:
target_prob
=
tl
.
load
(
target_probs_ptr
+
logit_idx
*
target_probs_stride
+
draft_sampled
).
to
(
tl
.
float64
)
draft_prob
=
tl
.
load
(
draft_probs_ptr
+
logit_idx
*
draft_probs_stride
+
draft_sampled
).
to
(
tl
.
float64
)
pos
=
tl
.
load
(
pos_ptr
+
logit_idx
)
u
=
tl_rand64
(
seed
,
pos
,
includes_zero
=
False
)
accepted
&=
target_prob
>
u
*
draft_prob
tl
.
store
(
sampled_ptr
+
req_idx
*
sampled_stride
+
i
,
draft_sampled
)
rejected_step
+=
accepted
tl
.
store
(
rejected_steps_ptr
+
req_idx
,
rejected_step
)
pos_val
=
tl
.
load
(
pos_ptr
+
start_idx
+
rejected_step
)
tl
.
store
(
rejected_pos_ptr
+
req_idx
,
pos_val
)
@
triton
.
jit
def
_compute_residual_logits_kernel
(
# [num_reqs, V]
residual_logits_ptr
,
residual_logits_stride
,
# [num_logits, V]
target_probs_ptr
,
target_probs_stride
,
# [num_logits, V]
draft_probs_ptr
,
draft_probs_stride
,
# [num_logits, V]
target_logits_ptr
,
target_logits_stride
,
# [num_reqs]
rejected_step_ptr
,
# [num_reqs + 1]
cu_num_logits_ptr
,
# [num_reqs]
idx_mapping_ptr
,
# [max_num_reqs]
temp_ptr
,
vocab_size
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
block_idx
=
tl
.
program_id
(
1
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_idx
)
start_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
)
end_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
+
1
)
rejected_logit_idx
=
start_idx
+
tl
.
load
(
rejected_step_ptr
+
req_idx
)
temp
=
tl
.
load
(
temp_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
block_offsets
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block_offsets
<
vocab_size
if
temp
==
0.0
or
(
rejected_logit_idx
==
end_idx
-
1
):
# Greedy sampling / bonus token. In either case, use the
# target logits directly to reduce numerical error.
residual_logits
=
tl
.
load
(
target_logits_ptr
+
rejected_logit_idx
*
target_logits_stride
+
block_offsets
,
mask
=
mask
,
other
=
float
(
"-inf"
),
)
else
:
target_probs
=
tl
.
load
(
target_probs_ptr
+
rejected_logit_idx
*
target_probs_stride
+
block_offsets
,
mask
=
mask
,
other
=
0.0
,
)
draft_probs
=
tl
.
load
(
draft_probs_ptr
+
rejected_logit_idx
*
draft_probs_stride
+
block_offsets
,
mask
=
mask
,
other
=
0.0
,
)
residual_probs
=
tl
.
maximum
(
target_probs
-
draft_probs
,
0.0
)
residual_logits
=
tl
.
log
(
residual_probs
)
tl
.
store
(
residual_logits_ptr
+
req_idx
*
residual_logits_stride
+
block_offsets
,
residual_logits
,
mask
=
mask
,
)
def
probabilistic_rejection_sample
(
# [num_logits, V]
target_logits
:
torch
.
Tensor
,
# [max_num_reqs, num_speculative_steps, V]
draft_logits
:
torch
.
Tensor
,
# [num_logits]
draft_sampled
:
torch
.
Tensor
,
# [num_reqs + 1]
cu_num_logits
:
torch
.
Tensor
,
# [num_logits]
pos
:
torch
.
Tensor
,
# [num_reqs]
idx_mapping
:
torch
.
Tensor
,
# [num_logits]
expanded_idx_mapping
:
torch
.
Tensor
,
# [num_logits]
expanded_local_pos
:
torch
.
Tensor
,
# [max_num_reqs]
temperature
:
torch
.
Tensor
,
# [max_num_reqs]
seed
:
torch
.
Tensor
,
num_speculative_steps
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num_reqs
=
cu_num_logits
.
shape
[
0
]
-
1
num_logits
,
vocab_size
=
target_logits
.
shape
BLOCK_SIZE
=
1024
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
# Gather draft logits and target argmax for greedy sampling.
gathered_draft_logits
=
target_logits
.
new_empty
(
target_logits
.
shape
)
local_target_argmax
=
target_logits
.
new_empty
(
num_logits
,
num_blocks
,
dtype
=
torch
.
int64
)
local_target_max
=
target_logits
.
new_empty
(
num_logits
,
num_blocks
,
dtype
=
torch
.
float32
)
_gather_draft_logits_and_target_argmax_kernel
[(
num_logits
,
num_blocks
)](
local_target_argmax
,
local_target_argmax
.
stride
(
0
),
local_target_max
,
local_target_max
.
stride
(
0
),
gathered_draft_logits
,
gathered_draft_logits
.
stride
(
0
),
target_logits
,
target_logits
.
stride
(
0
),
draft_logits
,
draft_logits
.
stride
(
0
),
draft_logits
.
stride
(
1
),
expanded_idx_mapping
,
expanded_local_pos
,
temperature
,
vocab_size
,
num_speculative_steps
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
# Compute target and draft probs.
target_probs
=
torch
.
softmax
(
target_logits
,
dim
=-
1
)
draft_probs
=
torch
.
softmax
(
gathered_draft_logits
,
dim
=-
1
)
# Rejection sample.
# [num_reqs, num_speculative_steps + 1]
sampled
=
draft_sampled
.
new_empty
(
num_reqs
,
num_speculative_steps
+
1
,
dtype
=
torch
.
int64
)
# [num_reqs]
rejected_steps
=
sampled
.
new_empty
(
num_reqs
)
# [num_reqs]
rejected_pos
=
pos
.
new_empty
(
num_reqs
)
_probabilistic_rejection_kernel
[(
num_reqs
,)](
sampled
,
sampled
.
stride
(
0
),
rejected_steps
,
rejected_pos
,
draft_sampled
,
target_probs
,
target_probs
.
stride
(
0
),
draft_probs
,
draft_probs
.
stride
(
0
),
local_target_argmax
,
local_target_argmax
.
stride
(
0
),
local_target_max
,
local_target_max
.
stride
(
0
),
cu_num_logits
,
pos
,
idx_mapping
,
temperature
,
seed
,
num_warps
=
1
,
NUM_BLOCKS
=
num_blocks
,
PADDED_NUM_BLOCKS
=
triton
.
next_power_of_2
(
num_blocks
),
)
# Compute the logits and positions to resample the rejected/bonus
# tokens from.
# [num_reqs, vocab_size]
residual_logits
=
target_logits
.
new_empty
(
num_reqs
,
vocab_size
)
_compute_residual_logits_kernel
[(
num_reqs
,
num_blocks
)](
residual_logits
,
residual_logits
.
stride
(
0
),
target_probs
,
target_probs
.
stride
(
0
),
draft_probs
,
draft_probs
.
stride
(
0
),
target_logits
,
target_logits
.
stride
(
0
),
rejected_steps
,
cu_num_logits
,
idx_mapping
,
temperature
,
vocab_size
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
# Gumbel sample tokens from the residual distribution.
resampled
=
gumbel_sample
(
residual_logits
,
idx_mapping
,
temperature
,
seed
,
rejected_pos
,
apply_temperature
=
False
,
)
sampled
.
scatter_
(
1
,
rejected_steps
.
unsqueeze
(
1
),
resampled
.
unsqueeze
(
1
))
return
sampled
,
rejected_steps
+
1
@
triton
.
jit
@
triton
.
jit
def
_flatten_sampled_kernel
(
def
_flatten_sampled_kernel
(
# [num_logits]
# [num_logits]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment