Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e9af6ba6
Unverified
Commit
e9af6ba6
authored
Nov 21, 2025
by
Woosuk Kwon
Committed by
GitHub
Nov 21, 2025
Browse files
[Model Runner V2] Optimize Gumbel Sampling Kernel (#29210)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
c6fa3895
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
50 deletions
+43
-50
vllm/v1/worker/gpu/sampler.py
vllm/v1/worker/gpu/sampler.py
+43
-50
No files found.
vllm/v1/worker/gpu/sampler.py
View file @
e9af6ba6
...
...
@@ -3,10 +3,9 @@
from
collections.abc
import
Callable
import
torch
import
triton
import
triton.language
as
tl
from
vllm.config.model
import
LogprobsMode
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.outputs
import
LogprobsTensors
,
SamplerOutput
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
from
vllm.v1.worker.gpu.states
import
SamplingMetadata
...
...
@@ -78,7 +77,10 @@ class Sampler:
@
triton
.
jit
def
_gumbel_sample_kernel
(
sampled_ptr
,
local_argmax_ptr
,
local_argmax_stride
,
local_max_ptr
,
local_max_stride
,
logits_ptr
,
logits_stride
,
seeds_ptr
,
...
...
@@ -88,40 +90,21 @@ def _gumbel_sample_kernel(
BLOCK_SIZE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
is_greedy
=
tl
.
load
(
is_greedy_ptr
+
req_idx
)
if
is_greedy
:
# Greedy sampling. Don't apply gumbel noise.
max_val
=
float
(
"-inf"
)
max_idx
=
0
for
i
in
range
(
0
,
vocab_size
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
logits
=
tl
.
load
(
logits_ptr
+
req_idx
*
logits_stride
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
),
)
idx
=
tl
.
argmax
(
logits
,
axis
=
0
)
value
=
tl
.
max
(
logits
,
axis
=
0
)
is_greater
=
value
>
max_val
max_val
=
tl
.
where
(
is_greater
,
value
,
max_val
)
max_idx
=
tl
.
where
(
is_greater
,
i
+
idx
,
max_idx
)
tl
.
store
(
sampled_ptr
+
req_idx
,
max_idx
)
return
# Random sampling.
# Calculate gumbel seed.
seed
=
tl
.
load
(
seeds_ptr
+
req_idx
)
pos
=
tl
.
load
(
pos_ptr
+
req_idx
)
gumbel_seed
=
tl
.
randint
(
seed
,
pos
)
block_idx
=
tl
.
program_id
(
1
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
logits
=
tl
.
load
(
logits_ptr
+
req_idx
*
logits_stride
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
),
)
max_val
=
float
(
"-inf"
)
max_idx
=
0
for
i
in
range
(
0
,
vocab_size
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
is_greedy
=
tl
.
load
(
is_greedy_ptr
+
req_idx
)
if
not
is_greedy
:
# Calculate the seed for gumbel noise.
seed
=
tl
.
load
(
seeds_ptr
+
req_idx
)
pos
=
tl
.
load
(
pos_ptr
+
req_idx
)
gumbel_seed
=
tl
.
randint
(
seed
,
pos
)
# Generate gumbel noise.
r
=
tl
.
rand
(
gumbel_seed
,
block
).
to
(
tl
.
float64
)
...
...
@@ -129,16 +112,13 @@ def _gumbel_sample_kernel(
gumbel_noise
=
gumbel_noise
.
to
(
tl
.
float32
)
# Apply gumbel noise.
logits
=
tl
.
load
(
logits_ptr
+
req_idx
*
logits_stride
+
block
,
mask
=
mask
)
logits
=
tl
.
where
(
mask
,
logits
+
gumbel_noise
,
float
(
"-inf"
))
# Argmax to get the sampled token.
idx
=
tl
.
argmax
(
logits
,
axis
=
0
)
value
=
tl
.
max
(
logits
,
axis
=
0
)
is_greater
=
value
>
max_val
max_val
=
tl
.
where
(
is_greater
,
value
,
max_val
)
max_idx
=
tl
.
where
(
is_greater
,
i
+
idx
,
max_idx
)
tl
.
store
(
sampled_ptr
+
req_idx
,
max_idx
)
idx
=
tl
.
argmax
(
logits
,
axis
=
0
)
token_id
=
block_idx
*
BLOCK_SIZE
+
idx
value
=
tl
.
max
(
logits
,
axis
=
0
)
tl
.
store
(
local_argmax_ptr
+
req_idx
*
local_argmax_stride
+
block_idx
,
token_id
)
tl
.
store
(
local_max_ptr
+
req_idx
*
local_max_stride
+
block_idx
,
value
)
def
gumbel_sample
(
...
...
@@ -148,23 +128,36 @@ def gumbel_sample(
pos
:
torch
.
Tensor
,
# [num_reqs]
)
->
torch
.
Tensor
:
num_reqs
,
vocab_size
=
logits
.
shape
# NOTE(woosuk): Use int64 for later indexing.
sampled
=
torch
.
empty
(
BLOCK_SIZE
=
1024
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
local_argmax
=
torch
.
empty
(
num_reqs
,
num_blocks
,
dtype
=
torch
.
int64
,
device
=
logits
.
device
,
)
_gumbel_sample_kernel
[(
num_reqs
,)](
sampled
,
local_max
=
torch
.
empty
(
num_reqs
,
num_blocks
,
dtype
=
torch
.
float32
,
device
=
logits
.
device
,
)
_gumbel_sample_kernel
[(
num_reqs
,
num_blocks
)](
local_argmax
,
local_argmax
.
stride
(
0
),
local_max
,
local_max
.
stride
(
0
),
logits
,
logits
.
stride
(
0
),
seed
,
pos
,
is_greedy
,
vocab_size
,
num_warps
=
8
,
BLOCK_SIZE
=
16384
,
# type: ignore
BLOCK_SIZE
=
BLOCK_SIZE
,
)
# NOTE(woosuk): Use int64 for later indexing.
max_block_idx
=
local_max
.
argmax
(
dim
=-
1
,
keepdim
=
True
)
sampled
=
local_argmax
.
gather
(
dim
=-
1
,
index
=
max_block_idx
).
view
(
-
1
)
return
sampled
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment