Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a5e9d511
Unverified
Commit
a5e9d511
authored
Mar 22, 2026
by
Woosuk Kwon
Committed by
GitHub
Mar 22, 2026
Browse files
[MRV2] Use FP64 for Gumbel noise (#37798)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
c058ff44
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
8 deletions
+24
-8
vllm/v1/worker/gpu/sample/gumbel.py
vllm/v1/worker/gpu/sample/gumbel.py
+20
-4
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
+4
-4
No files found.
vllm/v1/worker/gpu/sample/gumbel.py
View file @
a5e9d511
...
...
@@ -49,6 +49,21 @@ def apply_temperature(
)
@triton.jit
def tl_rand64(seed, offset, includes_zero: tl.constexpr):
    # Build a float64 uniform sample from 64 random bits.
    #
    # seed / offset are forwarded unchanged to tl.randint4x; the result has
    # whatever shape that call produces for `offset`.
    # includes_zero (compile-time flag): when False, an exact 0.0 is replaced
    # by the smallest positive normal float64 so callers can safely take
    # log(u). The result is always strictly below 1.0.

    # Only the first two of the four random words are consumed.
    lo, hi, _, _ = tl.randint4x(seed, offset)
    # Reinterpret each word as unsigned before widening so that no sign
    # extension leaks into the upper 32 bits of the combined value.
    lo = lo.to(tl.uint32, bitcast=True).to(tl.uint64)
    hi = hi.to(tl.uint32, bitcast=True).to(tl.uint64)
    r = (hi << 32) | lo

    # 1 / 2**64
    scale = 5.421010862427522170037e-20
    u = r.to(tl.float64) * scale

    # The uint64 -> float64 conversion rounds to nearest, so values of r
    # within roughly 2**10 of 2**64 become exactly 2**64 and u would be
    # exactly 1.0, which turns -log(-log(u)) into +inf. Map that single case
    # to the largest float64 below 1.0 (1 - 2**-53). The subtraction is exact,
    # and 2**-53 is a power of two, so the literal is exact in fp32 and fp64.
    u = tl.where(u < 1.0, u, u - 1.1102230246251565e-16)

    if not includes_zero:
        u = tl.maximum(u, 2.2250738585072014e-308)  # float64 tiny
    return u
@
triton
.
jit
def
_gumbel_sample_kernel
(
local_argmax_ptr
,
...
...
@@ -95,15 +110,16 @@ def _gumbel_sample_kernel(
mask
=
mask
,
)
logits
=
logits
.
to
(
tl
.
float64
)
if
temp
!=
0.0
:
# Calculate the seed for gumbel noise.
seed
=
tl
.
load
(
seeds_ptr
+
req_state_idx
)
pos
=
tl
.
load
(
pos_ptr
+
token_idx
)
gumbel_seed
=
tl
.
randint
(
seed
,
pos
)
#
Generate gumbel noise in FP32.
u
=
tl
.
rand
(
gumbel_seed
,
block
)
u
=
tl
.
maximum
(
u
,
1e-7
)
#
tl.rand returns fp32, so build a true fp64 uniform from 64 random
# bits before applying the double-log transform.
u
=
tl
_rand64
(
gumbel_seed
,
block
,
includes_zero
=
False
)
gumbel_noise
=
-
tl
.
log
(
-
tl
.
log
(
u
))
# Apply gumbel noise.
...
...
@@ -128,7 +144,7 @@ def gumbel_sample(
BLOCK_SIZE
=
1024
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
local_argmax
=
logits
.
new_empty
(
num_tokens
,
num_blocks
,
dtype
=
torch
.
int64
)
local_max
=
logits
.
new_empty
(
num_tokens
,
num_blocks
,
dtype
=
torch
.
float
32
)
local_max
=
logits
.
new_empty
(
num_tokens
,
num_blocks
,
dtype
=
torch
.
float
64
)
_gumbel_sample_kernel
[(
num_tokens
,
num_blocks
)](
local_argmax
,
local_argmax
.
stride
(
0
),
...
...
vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
View file @
a5e9d511
...
...
@@ -6,7 +6,7 @@ from vllm.triton_utils import tl, triton
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.metrics.logits
import
get_num_nans
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
,
tl_rand64
from
vllm.v1.worker.gpu.sample.logprob
import
compute_topk_logprobs
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
...
...
@@ -211,12 +211,12 @@ def _probabilistic_rejection_kernel(
else
:
target_prob
=
tl
.
load
(
target_probs_ptr
+
logit_idx
*
target_probs_stride
+
draft_sampled
)
)
.
to
(
tl
.
float64
)
draft_prob
=
tl
.
load
(
draft_probs_ptr
+
logit_idx
*
draft_probs_stride
+
draft_sampled
)
)
.
to
(
tl
.
float64
)
pos
=
tl
.
load
(
pos_ptr
+
logit_idx
)
u
=
tl
.
sum
(
tl
.
rand
(
seed
,
pos
+
tl
.
arange
(
0
,
1
))
)
u
=
tl
_
rand
64
(
seed
,
pos
,
includes_zero
=
False
)
accepted
&=
target_prob
>
u
*
draft_prob
tl
.
store
(
sampled_ptr
+
req_idx
*
sampled_stride
+
i
,
draft_sampled
)
rejected_step
+=
accepted
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment