Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0a7165fd
Unverified
Commit
0a7165fd
authored
Mar 03, 2026
by
Andy Lo
Committed by
GitHub
Mar 02, 2026
Browse files
[ModelRunnerV2] Rename sampler functions and variables for clarity (#35459)
Signed-off-by:
Andy Lo
<
andy@mistral.ai
>
parent
6521ccf2
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
99 additions
and
91 deletions
+99
-91
vllm/v1/worker/gpu/sample/bad_words.py
vllm/v1/worker/gpu/sample/bad_words.py
+9
-9
vllm/v1/worker/gpu/sample/gumbel.py
vllm/v1/worker/gpu/sample/gumbel.py
+24
-24
vllm/v1/worker/gpu/sample/logit_bias.py
vllm/v1/worker/gpu/sample/logit_bias.py
+16
-16
vllm/v1/worker/gpu/sample/min_p.py
vllm/v1/worker/gpu/sample/min_p.py
+14
-10
vllm/v1/worker/gpu/sample/penalties.py
vllm/v1/worker/gpu/sample/penalties.py
+15
-15
vllm/v1/worker/gpu/sample/sampler.py
vllm/v1/worker/gpu/sample/sampler.py
+14
-10
vllm/v1/worker/gpu/sample/states.py
vllm/v1/worker/gpu/sample/states.py
+7
-7
No files found.
vllm/v1/worker/gpu/sample/bad_words.py
View file @
0a7165fd
...
@@ -72,7 +72,7 @@ class BadWordsState:
...
@@ -72,7 +72,7 @@ class BadWordsState:
def
apply_bad_words
(
def
apply_bad_words
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
expanded_
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
idx_mapping_np
:
np
.
ndarray
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
expanded_local_pos
:
torch
.
Tensor
,
expanded_local_pos
:
torch
.
Tensor
,
...
@@ -84,7 +84,7 @@ class BadWordsState:
...
@@ -84,7 +84,7 @@ class BadWordsState:
apply_bad_words
(
apply_bad_words
(
logits
,
logits
,
idx_mapping
,
expanded_
idx_mapping
,
self
.
bad_word_token_ids
.
gpu
,
self
.
bad_word_token_ids
.
gpu
,
self
.
bad_word_offsets
.
gpu
,
self
.
bad_word_offsets
.
gpu
,
self
.
num_bad_words
.
gpu
,
self
.
num_bad_words
.
gpu
,
...
@@ -114,17 +114,17 @@ def _bad_words_kernel(
...
@@ -114,17 +114,17 @@ def _bad_words_kernel(
input_ids_ptr
,
input_ids_ptr
,
expanded_local_pos_ptr
,
expanded_local_pos_ptr
,
):
):
logit
_idx
=
tl
.
program_id
(
0
)
token
_idx
=
tl
.
program_id
(
0
)
bw_idx
=
tl
.
program_id
(
1
)
bw_idx
=
tl
.
program_id
(
1
)
req_state_idx
=
tl
.
load
(
expanded_idx_mapping_ptr
+
logit
_idx
)
req_state_idx
=
tl
.
load
(
expanded_idx_mapping_ptr
+
token
_idx
)
num_bad_words
=
tl
.
load
(
num_bad_words_ptr
+
req_state_idx
)
num_bad_words
=
tl
.
load
(
num_bad_words_ptr
+
req_state_idx
)
if
bw_idx
>=
num_bad_words
:
if
bw_idx
>=
num_bad_words
:
return
return
pos
=
tl
.
load
(
expanded_local_pos_ptr
+
logit
_idx
)
pos
=
tl
.
load
(
expanded_local_pos_ptr
+
token
_idx
)
cur_req_first_pos
=
logit
_idx
-
pos
cur_req_first_pos
=
token
_idx
-
pos
prompt_len
=
tl
.
load
(
prompt_len_ptr
+
req_state_idx
)
prompt_len
=
tl
.
load
(
prompt_len_ptr
+
req_state_idx
)
total_len
=
tl
.
load
(
total_len_ptr
+
req_state_idx
)
total_len
=
tl
.
load
(
total_len_ptr
+
req_state_idx
)
...
@@ -159,7 +159,7 @@ def _bad_words_kernel(
...
@@ -159,7 +159,7 @@ def _bad_words_kernel(
match
=
match
&
(
expected
==
actual
)
match
=
match
&
(
expected
==
actual
)
if
match
:
if
match
:
tl
.
store
(
logits_ptr
+
logit
_idx
*
logits_stride
+
last_token
,
-
float
(
"inf"
))
tl
.
store
(
logits_ptr
+
token
_idx
*
logits_stride
+
last_token
,
-
float
(
"inf"
))
def
apply_bad_words
(
def
apply_bad_words
(
...
@@ -175,8 +175,8 @@ def apply_bad_words(
...
@@ -175,8 +175,8 @@ def apply_bad_words(
expanded_local_pos
:
torch
.
Tensor
,
expanded_local_pos
:
torch
.
Tensor
,
max_num_bad_words
:
int
,
max_num_bad_words
:
int
,
)
->
None
:
)
->
None
:
total_
num_tokens
=
logits
.
shape
[
0
]
num_tokens
=
logits
.
shape
[
0
]
_bad_words_kernel
[(
total_
num_tokens
,
max_num_bad_words
)](
_bad_words_kernel
[(
num_tokens
,
max_num_bad_words
)](
logits
,
logits
,
logits
.
stride
(
0
),
logits
.
stride
(
0
),
expanded_idx_mapping
,
expanded_idx_mapping
,
...
...
vllm/v1/worker/gpu/sample/gumbel.py
View file @
0a7165fd
...
@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
...
@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
def
_temperature_kernel
(
def
_temperature_kernel
(
logits_ptr
,
logits_ptr
,
logits_stride
,
logits_stride
,
idx_mapping_ptr
,
expanded_
idx_mapping_ptr
,
temperature_ptr
,
temperature_ptr
,
vocab_size
,
vocab_size
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
):
batch
_idx
=
tl
.
program_id
(
0
)
token
_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch
_idx
)
req_state_idx
=
tl
.
load
(
expanded_
idx_mapping_ptr
+
token
_idx
)
temperature
=
tl
.
load
(
temperature_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
temperature
=
tl
.
load
(
temperature_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
if
temperature
==
0.0
or
temperature
==
1.0
:
if
temperature
==
0.0
or
temperature
==
1.0
:
# Early return to avoid loading logits.
# Early return to avoid loading logits.
...
@@ -25,24 +25,24 @@ def _temperature_kernel(
...
@@ -25,24 +25,24 @@ def _temperature_kernel(
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
mask
=
block
<
vocab_size
logits
=
tl
.
load
(
logits_ptr
+
batch
_idx
*
logits_stride
+
block
,
mask
=
mask
)
logits
=
tl
.
load
(
logits_ptr
+
token
_idx
*
logits_stride
+
block
,
mask
=
mask
)
logits
=
logits
.
to
(
tl
.
float32
)
logits
=
logits
.
to
(
tl
.
float32
)
logits
=
logits
/
temperature
logits
=
logits
/
temperature
tl
.
store
(
logits_ptr
+
batch
_idx
*
logits_stride
+
block
,
logits
,
mask
=
mask
)
tl
.
store
(
logits_ptr
+
token
_idx
*
logits_stride
+
block
,
logits
,
mask
=
mask
)
def
apply_temperature
(
def
apply_temperature
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
expanded_
idx_mapping
:
torch
.
Tensor
,
temperature
:
torch
.
Tensor
,
temperature
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
num_
req
s
,
vocab_size
=
logits
.
shape
num_
token
s
,
vocab_size
=
logits
.
shape
BLOCK_SIZE
=
8192
BLOCK_SIZE
=
8192
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
_temperature_kernel
[(
num_
req
s
,
num_blocks
)](
_temperature_kernel
[(
num_
token
s
,
num_blocks
)](
logits
,
logits
,
logits
.
stride
(
0
),
logits
.
stride
(
0
),
idx_mapping
,
expanded_
idx_mapping
,
temperature
,
temperature
,
vocab_size
,
vocab_size
,
BLOCK_SIZE
=
BLOCK_SIZE
,
BLOCK_SIZE
=
BLOCK_SIZE
,
...
@@ -57,7 +57,7 @@ def _gumbel_sample_kernel(
...
@@ -57,7 +57,7 @@ def _gumbel_sample_kernel(
local_max_stride
,
local_max_stride
,
logits_ptr
,
logits_ptr
,
logits_stride
,
logits_stride
,
idx_mapping_ptr
,
expanded_
idx_mapping_ptr
,
seeds_ptr
,
seeds_ptr
,
pos_ptr
,
pos_ptr
,
temp_ptr
,
temp_ptr
,
...
@@ -65,14 +65,14 @@ def _gumbel_sample_kernel(
...
@@ -65,14 +65,14 @@ def _gumbel_sample_kernel(
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
APPLY_TEMPERATURE
:
tl
.
constexpr
,
APPLY_TEMPERATURE
:
tl
.
constexpr
,
):
):
batch
_idx
=
tl
.
program_id
(
0
)
token
_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch
_idx
)
req_state_idx
=
tl
.
load
(
expanded_
idx_mapping_ptr
+
token
_idx
)
block_idx
=
tl
.
program_id
(
1
)
block_idx
=
tl
.
program_id
(
1
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
mask
=
block
<
vocab_size
logits
=
tl
.
load
(
logits
=
tl
.
load
(
logits_ptr
+
batch
_idx
*
logits_stride
+
block
,
logits_ptr
+
token
_idx
*
logits_stride
+
block
,
mask
=
mask
,
mask
=
mask
,
other
=
float
(
"-inf"
),
other
=
float
(
"-inf"
),
)
)
...
@@ -82,7 +82,7 @@ def _gumbel_sample_kernel(
...
@@ -82,7 +82,7 @@ def _gumbel_sample_kernel(
if
temp
!=
0.0
:
if
temp
!=
0.0
:
# Calculate the seed for gumbel noise.
# Calculate the seed for gumbel noise.
seed
=
tl
.
load
(
seeds_ptr
+
req_state_idx
)
seed
=
tl
.
load
(
seeds_ptr
+
req_state_idx
)
pos
=
tl
.
load
(
pos_ptr
+
batch
_idx
)
pos
=
tl
.
load
(
pos_ptr
+
token
_idx
)
gumbel_seed
=
tl
.
randint
(
seed
,
pos
)
gumbel_seed
=
tl
.
randint
(
seed
,
pos
)
# Generate gumbel noise in FP32.
# Generate gumbel noise in FP32.
...
@@ -101,41 +101,41 @@ def _gumbel_sample_kernel(
...
@@ -101,41 +101,41 @@ def _gumbel_sample_kernel(
value
,
idx
=
tl
.
max
(
logits
,
axis
=
0
,
return_indices
=
True
)
value
,
idx
=
tl
.
max
(
logits
,
axis
=
0
,
return_indices
=
True
)
token_id
=
block_idx
*
BLOCK_SIZE
+
idx
token_id
=
block_idx
*
BLOCK_SIZE
+
idx
tl
.
store
(
local_argmax_ptr
+
batch
_idx
*
local_argmax_stride
+
block_idx
,
token_id
)
tl
.
store
(
local_argmax_ptr
+
token
_idx
*
local_argmax_stride
+
block_idx
,
token_id
)
tl
.
store
(
local_max_ptr
+
batch
_idx
*
local_max_stride
+
block_idx
,
value
)
tl
.
store
(
local_max_ptr
+
token
_idx
*
local_max_stride
+
block_idx
,
value
)
def
gumbel_sample
(
def
gumbel_sample
(
logits
:
torch
.
Tensor
,
# [num_
req
s, vocab_size]
logits
:
torch
.
Tensor
,
# [num_
token
s, vocab_size]
idx_mapping
:
torch
.
Tensor
,
# [
max_num_req
s]
expanded_
idx_mapping
:
torch
.
Tensor
,
# [
num_token
s]
temperature
:
torch
.
Tensor
,
# [max_num_reqs]
temperature
:
torch
.
Tensor
,
# [max_num_reqs]
seed
:
torch
.
Tensor
,
# [max_num_reqs]
seed
:
torch
.
Tensor
,
# [max_num_reqs]
pos
:
torch
.
Tensor
,
# [num_
req
s]
pos
:
torch
.
Tensor
,
# [num_
token
s]
apply_temperature
:
bool
,
apply_temperature
:
bool
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_
req
s
,
vocab_size
=
logits
.
shape
num_
token
s
,
vocab_size
=
logits
.
shape
BLOCK_SIZE
=
1024
BLOCK_SIZE
=
1024
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
local_argmax
=
torch
.
empty
(
local_argmax
=
torch
.
empty
(
num_
req
s
,
num_
token
s
,
num_blocks
,
num_blocks
,
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
logits
.
device
,
device
=
logits
.
device
,
)
)
local_max
=
torch
.
empty
(
local_max
=
torch
.
empty
(
num_
req
s
,
num_
token
s
,
num_blocks
,
num_blocks
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
logits
.
device
,
device
=
logits
.
device
,
)
)
_gumbel_sample_kernel
[(
num_
req
s
,
num_blocks
)](
_gumbel_sample_kernel
[(
num_
token
s
,
num_blocks
)](
local_argmax
,
local_argmax
,
local_argmax
.
stride
(
0
),
local_argmax
.
stride
(
0
),
local_max
,
local_max
,
local_max
.
stride
(
0
),
local_max
.
stride
(
0
),
logits
,
logits
,
logits
.
stride
(
0
),
logits
.
stride
(
0
),
idx_mapping
,
expanded_
idx_mapping
,
seed
,
seed
,
pos
,
pos
,
temperature
,
temperature
,
...
...
vllm/v1/worker/gpu/sample/logit_bias.py
View file @
0a7165fd
...
@@ -121,7 +121,7 @@ class LogitBiasState:
...
@@ -121,7 +121,7 @@ class LogitBiasState:
def
apply_logit_bias
(
def
apply_logit_bias
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
expanded_
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
idx_mapping_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
pos
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
...
@@ -131,7 +131,7 @@ class LogitBiasState:
...
@@ -131,7 +131,7 @@ class LogitBiasState:
apply_logit_bias
(
apply_logit_bias
(
logits
,
logits
,
idx_mapping
,
expanded_
idx_mapping
,
pos
,
pos
,
self
.
num_allowed_token_ids
.
gpu
,
self
.
num_allowed_token_ids
.
gpu
,
self
.
allowed_token_ids
.
gpu
,
self
.
allowed_token_ids
.
gpu
,
...
@@ -149,7 +149,7 @@ def _bias_kernel(
...
@@ -149,7 +149,7 @@ def _bias_kernel(
logits_ptr
,
logits_ptr
,
logits_stride
,
logits_stride
,
vocab_size
,
vocab_size
,
idx_mapping_ptr
,
expanded_
idx_mapping_ptr
,
# Allowed token IDs.
# Allowed token IDs.
num_allowed_token_ids_ptr
,
num_allowed_token_ids_ptr
,
allowed_token_ids_ptr
,
allowed_token_ids_ptr
,
...
@@ -169,8 +169,8 @@ def _bias_kernel(
...
@@ -169,8 +169,8 @@ def _bias_kernel(
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
LOGITS_BLOCK_SIZE
:
tl
.
constexpr
,
LOGITS_BLOCK_SIZE
:
tl
.
constexpr
,
):
):
batch
_idx
=
tl
.
program_id
(
0
)
token
_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch
_idx
)
req_state_idx
=
tl
.
load
(
expanded_
idx_mapping_ptr
+
token
_idx
)
block
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
block
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
...
@@ -186,21 +186,21 @@ def _bias_kernel(
...
@@ -186,21 +186,21 @@ def _bias_kernel(
mask
=
mask
,
mask
=
mask
,
)
)
logits
=
tl
.
load
(
logits
=
tl
.
load
(
logits_ptr
+
batch
_idx
*
logits_stride
+
allowed_token_ids
,
mask
=
mask
logits_ptr
+
token
_idx
*
logits_stride
+
allowed_token_ids
,
mask
=
mask
)
)
# Set logits to -inf for all tokens.
# Set logits to -inf for all tokens.
for
i
in
range
(
0
,
vocab_size
,
LOGITS_BLOCK_SIZE
):
for
i
in
range
(
0
,
vocab_size
,
LOGITS_BLOCK_SIZE
):
offset
=
i
+
tl
.
arange
(
0
,
LOGITS_BLOCK_SIZE
)
offset
=
i
+
tl
.
arange
(
0
,
LOGITS_BLOCK_SIZE
)
tl
.
store
(
tl
.
store
(
logits_ptr
+
batch
_idx
*
logits_stride
+
offset
,
logits_ptr
+
token
_idx
*
logits_stride
+
offset
,
-
float
(
"inf"
),
-
float
(
"inf"
),
mask
=
offset
<
vocab_size
,
mask
=
offset
<
vocab_size
,
)
)
# Restore logits for allowed token IDs.
# Restore logits for allowed token IDs.
tl
.
store
(
tl
.
store
(
logits_ptr
+
batch
_idx
*
logits_stride
+
allowed_token_ids
,
logits_ptr
+
token
_idx
*
logits_stride
+
allowed_token_ids
,
logits
,
logits
,
mask
=
mask
,
mask
=
mask
,
)
)
...
@@ -214,13 +214,13 @@ def _bias_kernel(
...
@@ -214,13 +214,13 @@ def _bias_kernel(
mask
=
mask
,
mask
=
mask
,
)
)
bias
=
tl
.
load
(
bias_ptr
+
req_state_idx
*
bias_stride
+
block
,
mask
=
mask
)
bias
=
tl
.
load
(
bias_ptr
+
req_state_idx
*
bias_stride
+
block
,
mask
=
mask
)
logits
=
tl
.
load
(
logits_ptr
+
batch
_idx
*
logits_stride
+
token_ids
,
mask
=
mask
)
logits
=
tl
.
load
(
logits_ptr
+
token
_idx
*
logits_stride
+
token_ids
,
mask
=
mask
)
logits
+=
bias
logits
+=
bias
tl
.
store
(
logits_ptr
+
batch
_idx
*
logits_stride
+
token_ids
,
logits
,
mask
=
mask
)
tl
.
store
(
logits_ptr
+
token
_idx
*
logits_stride
+
token_ids
,
logits
,
mask
=
mask
)
# Apply min tokens.
# Apply min tokens.
num_stop_token_ids
=
tl
.
load
(
num_stop_token_ids_ptr
+
req_state_idx
)
num_stop_token_ids
=
tl
.
load
(
num_stop_token_ids_ptr
+
req_state_idx
)
pos
=
tl
.
load
(
pos_ptr
+
batch
_idx
)
pos
=
tl
.
load
(
pos_ptr
+
token
_idx
)
min_len
=
tl
.
load
(
min_lens_ptr
+
req_state_idx
)
min_len
=
tl
.
load
(
min_lens_ptr
+
req_state_idx
)
if
num_stop_token_ids
>
0
and
pos
<
min_len
:
if
num_stop_token_ids
>
0
and
pos
<
min_len
:
mask
=
block
<
num_stop_token_ids
mask
=
block
<
num_stop_token_ids
...
@@ -229,7 +229,7 @@ def _bias_kernel(
...
@@ -229,7 +229,7 @@ def _bias_kernel(
mask
=
mask
,
mask
=
mask
,
)
)
tl
.
store
(
tl
.
store
(
logits_ptr
+
batch
_idx
*
logits_stride
+
stop_token_ids
,
logits_ptr
+
token
_idx
*
logits_stride
+
stop_token_ids
,
-
float
(
"inf"
),
-
float
(
"inf"
),
mask
=
mask
,
mask
=
mask
,
)
)
...
@@ -237,7 +237,7 @@ def _bias_kernel(
...
@@ -237,7 +237,7 @@ def _bias_kernel(
def
apply_logit_bias
(
def
apply_logit_bias
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
expanded_
idx_mapping
:
torch
.
Tensor
,
pos
:
torch
.
Tensor
,
pos
:
torch
.
Tensor
,
num_allowed_token_ids
:
torch
.
Tensor
,
num_allowed_token_ids
:
torch
.
Tensor
,
allowed_token_ids
:
torch
.
Tensor
,
allowed_token_ids
:
torch
.
Tensor
,
...
@@ -248,7 +248,7 @@ def apply_logit_bias(
...
@@ -248,7 +248,7 @@ def apply_logit_bias(
num_stop_token_ids
:
torch
.
Tensor
,
num_stop_token_ids
:
torch
.
Tensor
,
stop_token_ids
:
torch
.
Tensor
,
stop_token_ids
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
num_
req
s
,
vocab_size
=
logits
.
shape
num_
token
s
,
vocab_size
=
logits
.
shape
BLOCK_SIZE
=
triton
.
next_power_of_2
(
BLOCK_SIZE
=
triton
.
next_power_of_2
(
max
(
max
(
allowed_token_ids
.
shape
[
-
1
],
allowed_token_ids
.
shape
[
-
1
],
...
@@ -257,11 +257,11 @@ def apply_logit_bias(
...
@@ -257,11 +257,11 @@ def apply_logit_bias(
)
)
)
)
LOGITS_BLOCK_SIZE
=
8192
LOGITS_BLOCK_SIZE
=
8192
_bias_kernel
[(
num_
req
s
,)](
_bias_kernel
[(
num_
token
s
,)](
logits
,
logits
,
logits
.
stride
(
0
),
logits
.
stride
(
0
),
vocab_size
,
vocab_size
,
idx_mapping
,
expanded_
idx_mapping
,
num_allowed_token_ids
,
num_allowed_token_ids
,
allowed_token_ids
,
allowed_token_ids
,
allowed_token_ids
.
stride
(
0
),
allowed_token_ids
.
stride
(
0
),
...
...
vllm/v1/worker/gpu/sample/min_p.py
View file @
0a7165fd
...
@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
...
@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
def
_min_p_kernel
(
def
_min_p_kernel
(
logits_ptr
,
logits_ptr
,
logits_stride
,
logits_stride
,
idx_mapping_ptr
,
expanded_
idx_mapping_ptr
,
min_p_ptr
,
min_p_ptr
,
vocab_size
,
vocab_size
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
):
req
_idx
=
tl
.
program_id
(
0
)
token
_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req
_idx
)
req_state_idx
=
tl
.
load
(
expanded_
idx_mapping_ptr
+
token
_idx
)
min_p
=
tl
.
load
(
min_p_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
min_p
=
tl
.
load
(
min_p_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
if
min_p
==
0.0
:
if
min_p
==
0.0
:
return
return
...
@@ -25,7 +25,9 @@ def _min_p_kernel(
...
@@ -25,7 +25,9 @@ def _min_p_kernel(
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
mask
=
block
<
vocab_size
logits
=
tl
.
load
(
logits
=
tl
.
load
(
logits_ptr
+
req_idx
*
logits_stride
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
)
logits_ptr
+
token_idx
*
logits_stride
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
),
)
)
max_val
=
tl
.
max
(
tl
.
maximum
(
logits
,
max_val
))
max_val
=
tl
.
max
(
tl
.
maximum
(
logits
,
max_val
))
max_val
=
max_val
.
to
(
tl
.
float32
)
# type: ignore
max_val
=
max_val
.
to
(
tl
.
float32
)
# type: ignore
...
@@ -35,21 +37,23 @@ def _min_p_kernel(
...
@@ -35,21 +37,23 @@ def _min_p_kernel(
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
mask
=
block
<
vocab_size
logits
=
tl
.
load
(
logits
=
tl
.
load
(
logits_ptr
+
req_idx
*
logits_stride
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
)
logits_ptr
+
token_idx
*
logits_stride
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
),
)
)
logits
=
tl
.
where
(
logits
<
threshold
,
float
(
"-inf"
),
logits
)
logits
=
tl
.
where
(
logits
<
threshold
,
float
(
"-inf"
),
logits
)
tl
.
store
(
logits_ptr
+
req
_idx
*
logits_stride
+
block
,
logits
,
mask
=
mask
)
tl
.
store
(
logits_ptr
+
token
_idx
*
logits_stride
+
block
,
logits
,
mask
=
mask
)
def
apply_min_p
(
def
apply_min_p
(
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
min_p
:
torch
.
Tensor
logits
:
torch
.
Tensor
,
expanded_
idx_mapping
:
torch
.
Tensor
,
min_p
:
torch
.
Tensor
)
->
None
:
)
->
None
:
num_
req
s
,
vocab_size
=
logits
.
shape
num_
token
s
,
vocab_size
=
logits
.
shape
BLOCK_SIZE
=
1024
BLOCK_SIZE
=
1024
_min_p_kernel
[(
num_
req
s
,)](
_min_p_kernel
[(
num_
token
s
,)](
logits
,
logits
,
logits
.
stride
(
0
),
logits
.
stride
(
0
),
idx_mapping
,
expanded_
idx_mapping
,
min_p
,
min_p
,
vocab_size
,
vocab_size
,
BLOCK_SIZE
=
BLOCK_SIZE
,
BLOCK_SIZE
=
BLOCK_SIZE
,
...
...
vllm/v1/worker/gpu/sample/penalties.py
View file @
0a7165fd
...
@@ -82,7 +82,7 @@ class PenaltiesState:
...
@@ -82,7 +82,7 @@ class PenaltiesState:
def
apply_penalties
(
def
apply_penalties
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
expanded_
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
idx_mapping_np
:
np
.
ndarray
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
expanded_local_pos
:
torch
.
Tensor
,
expanded_local_pos
:
torch
.
Tensor
,
...
@@ -94,7 +94,7 @@ class PenaltiesState:
...
@@ -94,7 +94,7 @@ class PenaltiesState:
apply_penalties
(
apply_penalties
(
logits
,
logits
,
idx_mapping
,
expanded_
idx_mapping
,
input_ids
,
input_ids
,
expanded_local_pos
,
expanded_local_pos
,
self
.
repetition_penalty
.
gpu
,
self
.
repetition_penalty
.
gpu
,
...
@@ -110,7 +110,7 @@ class PenaltiesState:
...
@@ -110,7 +110,7 @@ class PenaltiesState:
def
_penalties_kernel
(
def
_penalties_kernel
(
logits_ptr
,
logits_ptr
,
logits_stride
,
logits_stride
,
idx_mapping_ptr
,
expanded_
idx_mapping_ptr
,
token_ids_ptr
,
token_ids_ptr
,
expanded_local_pos_ptr
,
expanded_local_pos_ptr
,
repetition_penalty_ptr
,
repetition_penalty_ptr
,
...
@@ -125,7 +125,7 @@ def _penalties_kernel(
...
@@ -125,7 +125,7 @@ def _penalties_kernel(
MAX_SPEC_LEN
:
tl
.
constexpr
,
MAX_SPEC_LEN
:
tl
.
constexpr
,
):
):
token_idx
=
tl
.
program_id
(
0
)
token_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
token_idx
)
req_state_idx
=
tl
.
load
(
expanded_
idx_mapping_ptr
+
token_idx
)
rep_penalty
=
tl
.
load
(
repetition_penalty_ptr
+
req_state_idx
)
rep_penalty
=
tl
.
load
(
repetition_penalty_ptr
+
req_state_idx
)
freq_penalty
=
tl
.
load
(
frequency_penalty_ptr
+
req_state_idx
)
freq_penalty
=
tl
.
load
(
frequency_penalty_ptr
+
req_state_idx
)
pres_penalty
=
tl
.
load
(
presence_penalty_ptr
+
req_state_idx
)
pres_penalty
=
tl
.
load
(
presence_penalty_ptr
+
req_state_idx
)
...
@@ -191,7 +191,7 @@ def _penalties_kernel(
...
@@ -191,7 +191,7 @@ def _penalties_kernel(
def
apply_penalties
(
def
apply_penalties
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
expanded_
idx_mapping
:
torch
.
Tensor
,
token_ids
:
torch
.
Tensor
,
token_ids
:
torch
.
Tensor
,
expanded_local_pos
:
torch
.
Tensor
,
expanded_local_pos
:
torch
.
Tensor
,
repetition_penalty
:
torch
.
Tensor
,
repetition_penalty
:
torch
.
Tensor
,
...
@@ -207,7 +207,7 @@ def apply_penalties(
...
@@ -207,7 +207,7 @@ def apply_penalties(
_penalties_kernel
[(
num_tokens
,
num_blocks
)](
_penalties_kernel
[(
num_tokens
,
num_blocks
)](
logits
,
logits
,
logits
.
stride
(
0
),
logits
.
stride
(
0
),
idx_mapping
,
expanded_
idx_mapping
,
token_ids
,
token_ids
,
expanded_local_pos
,
expanded_local_pos
,
repetition_penalty
,
repetition_penalty
,
...
@@ -225,7 +225,7 @@ def apply_penalties(
...
@@ -225,7 +225,7 @@ def apply_penalties(
@
triton
.
jit
@
triton
.
jit
def
_bincount_kernel
(
def
_bincount_kernel
(
idx_mapping_ptr
,
expanded_
idx_mapping_ptr
,
all_token_ids_ptr
,
all_token_ids_ptr
,
all_token_ids_stride
,
all_token_ids_stride
,
prompt_len_ptr
,
prompt_len_ptr
,
...
@@ -236,9 +236,9 @@ def _bincount_kernel(
...
@@ -236,9 +236,9 @@ def _bincount_kernel(
output_bin_counts_stride
,
output_bin_counts_stride
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
):
batch
_idx
=
tl
.
program_id
(
0
)
token
_idx
=
tl
.
program_id
(
0
)
block_idx
=
tl
.
program_id
(
1
)
block_idx
=
tl
.
program_id
(
1
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch
_idx
)
req_state_idx
=
tl
.
load
(
expanded_
idx_mapping_ptr
+
token
_idx
)
prefill_len
=
tl
.
load
(
prefill_len_ptr
+
req_state_idx
)
prefill_len
=
tl
.
load
(
prefill_len_ptr
+
req_state_idx
)
if
block_idx
*
BLOCK_SIZE
>=
prefill_len
:
if
block_idx
*
BLOCK_SIZE
>=
prefill_len
:
...
@@ -276,7 +276,7 @@ def _bincount_kernel(
...
@@ -276,7 +276,7 @@ def _bincount_kernel(
def
bincount
(
def
bincount
(
idx_mapping
:
torch
.
Tensor
,
expanded_
idx_mapping
:
torch
.
Tensor
,
all_token_ids
:
torch
.
Tensor
,
all_token_ids
:
torch
.
Tensor
,
prompt_len
:
torch
.
Tensor
,
prompt_len
:
torch
.
Tensor
,
prefill_len
:
torch
.
Tensor
,
prefill_len
:
torch
.
Tensor
,
...
@@ -284,13 +284,13 @@ def bincount(
...
@@ -284,13 +284,13 @@ def bincount(
output_bin_counts
:
torch
.
Tensor
,
output_bin_counts
:
torch
.
Tensor
,
max_prefill_len
:
int
,
max_prefill_len
:
int
,
)
->
None
:
)
->
None
:
prompt_bin_mask
[
idx_mapping
]
=
0
prompt_bin_mask
[
expanded_
idx_mapping
]
=
0
output_bin_counts
[
idx_mapping
]
=
0
output_bin_counts
[
expanded_
idx_mapping
]
=
0
num_
reqs
=
idx_mapping
.
shape
[
0
]
num_
tokens
=
expanded_
idx_mapping
.
shape
[
0
]
BLOCK_SIZE
=
1024
BLOCK_SIZE
=
1024
num_blocks
=
triton
.
cdiv
(
max_prefill_len
,
BLOCK_SIZE
)
num_blocks
=
triton
.
cdiv
(
max_prefill_len
,
BLOCK_SIZE
)
_bincount_kernel
[(
num_
req
s
,
num_blocks
)](
_bincount_kernel
[(
num_
token
s
,
num_blocks
)](
idx_mapping
,
expanded_
idx_mapping
,
all_token_ids
,
all_token_ids
,
all_token_ids
.
stride
(
0
),
all_token_ids
.
stride
(
0
),
prompt_len
,
prompt_len
,
...
...
vllm/v1/worker/gpu/sample/sampler.py
View file @
0a7165fd
...
@@ -56,7 +56,7 @@ class Sampler:
...
@@ -56,7 +56,7 @@ class Sampler:
def
__call__
(
def
__call__
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
expanded_
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
idx_mapping_np
:
np
.
ndarray
,
cu_num_logits_np
:
np
.
ndarray
,
cu_num_logits_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
pos
:
torch
.
Tensor
,
...
@@ -68,7 +68,7 @@ class Sampler:
...
@@ -68,7 +68,7 @@ class Sampler:
num_nans
=
get_num_nans
(
logits
)
if
self
.
compute_nans
else
None
num_nans
=
get_num_nans
(
logits
)
if
self
.
compute_nans
else
None
sampled
,
processed_logits
=
self
.
sample
(
sampled
,
processed_logits
=
self
.
sample
(
logits
,
logits
,
idx_mapping
,
expanded_
idx_mapping
,
idx_mapping_np
,
idx_mapping_np
,
pos
,
pos
,
input_ids
,
input_ids
,
...
@@ -101,7 +101,7 @@ class Sampler:
...
@@ -101,7 +101,7 @@ class Sampler:
def
sample
(
def
sample
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
expanded_
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
idx_mapping_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
pos
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -111,12 +111,14 @@ class Sampler:
...
@@ -111,12 +111,14 @@ class Sampler:
logits
=
torch
.
empty_like
(
logits
,
dtype
=
torch
.
float32
).
copy_
(
logits
)
logits
=
torch
.
empty_like
(
logits
,
dtype
=
torch
.
float32
).
copy_
(
logits
)
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self
.
logit_bias_state
.
apply_logit_bias
(
logits
,
idx_mapping
,
idx_mapping_np
,
pos
)
self
.
logit_bias_state
.
apply_logit_bias
(
logits
,
expanded_idx_mapping
,
idx_mapping_np
,
pos
)
# Apply penalties in place.
# Apply penalties in place.
self
.
penalties_state
.
apply_penalties
(
self
.
penalties_state
.
apply_penalties
(
logits
,
logits
,
idx_mapping
,
expanded_
idx_mapping
,
idx_mapping_np
,
idx_mapping_np
,
input_ids
,
input_ids
,
expanded_local_pos
,
expanded_local_pos
,
...
@@ -126,27 +128,29 @@ class Sampler:
...
@@ -126,27 +128,29 @@ class Sampler:
# Apply bad words masking in place.
# Apply bad words masking in place.
self
.
bad_words_state
.
apply_bad_words
(
self
.
bad_words_state
.
apply_bad_words
(
logits
,
logits
,
idx_mapping
,
expanded_
idx_mapping
,
idx_mapping_np
,
idx_mapping_np
,
input_ids
,
input_ids
,
expanded_local_pos
,
expanded_local_pos
,
)
)
# Apply temperature in place.
# Apply temperature in place.
self
.
sampling_states
.
apply_temperature
(
logits
,
idx_mapping
,
idx_mapping_np
)
self
.
sampling_states
.
apply_temperature
(
logits
,
expanded_idx_mapping
,
idx_mapping_np
)
# Apply min_p in place.
# Apply min_p in place.
self
.
sampling_states
.
apply_min_p
(
logits
,
idx_mapping
,
idx_mapping_np
)
self
.
sampling_states
.
apply_min_p
(
logits
,
expanded_
idx_mapping
,
idx_mapping_np
)
# Apply top_k and/or top_p. This might or might not return a new tensor.
# Apply top_k and/or top_p. This might or might not return a new tensor.
logits
=
self
.
sampling_states
.
apply_top_k_top_p
(
logits
=
self
.
sampling_states
.
apply_top_k_top_p
(
logits
,
idx_mapping
,
idx_mapping_np
logits
,
expanded_
idx_mapping
,
idx_mapping_np
)
)
# Sample the next token.
# Sample the next token.
sampled
=
gumbel_sample
(
sampled
=
gumbel_sample
(
logits
,
logits
,
idx_mapping
,
expanded_
idx_mapping
,
self
.
sampling_states
.
temperature
.
gpu
,
self
.
sampling_states
.
temperature
.
gpu
,
self
.
sampling_states
.
seeds
.
gpu
,
self
.
sampling_states
.
seeds
.
gpu
,
pos
,
pos
,
...
...
vllm/v1/worker/gpu/sample/states.py
View file @
0a7165fd
...
@@ -64,7 +64,7 @@ class SamplingStates:
...
@@ -64,7 +64,7 @@ class SamplingStates:
def
apply_temperature
(
def
apply_temperature
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
expanded_
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
idx_mapping_np
:
np
.
ndarray
,
)
->
None
:
)
->
None
:
temp_np
=
self
.
temperature
.
np
[
idx_mapping_np
]
temp_np
=
self
.
temperature
.
np
[
idx_mapping_np
]
...
@@ -72,23 +72,23 @@ class SamplingStates:
...
@@ -72,23 +72,23 @@ class SamplingStates:
# No request requires temperature. Skip the kernel launch.
# No request requires temperature. Skip the kernel launch.
return
return
apply_temperature
(
logits
,
idx_mapping
,
self
.
temperature
.
gpu
)
apply_temperature
(
logits
,
expanded_
idx_mapping
,
self
.
temperature
.
gpu
)
def
apply_min_p
(
def
apply_min_p
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
expanded_
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
idx_mapping_np
:
np
.
ndarray
,
)
->
None
:
)
->
None
:
if
np
.
all
(
self
.
min_p
.
np
[
idx_mapping_np
]
==
0.0
):
if
np
.
all
(
self
.
min_p
.
np
[
idx_mapping_np
]
==
0.0
):
# No request uses min_p. Skip the kernel launch.
# No request uses min_p. Skip the kernel launch.
return
return
apply_min_p
(
logits
,
idx_mapping
,
self
.
min_p
.
gpu
)
apply_min_p
(
logits
,
expanded_
idx_mapping
,
self
.
min_p
.
gpu
)
def
apply_top_k_top_p
(
def
apply_top_k_top_p
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
expanded_
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
idx_mapping_np
:
np
.
ndarray
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
do_top_k
=
np
.
any
(
self
.
top_k
.
np
[
idx_mapping_np
]
!=
self
.
vocab_size
)
do_top_k
=
np
.
any
(
self
.
top_k
.
np
[
idx_mapping_np
]
!=
self
.
vocab_size
)
...
@@ -96,8 +96,8 @@ class SamplingStates:
...
@@ -96,8 +96,8 @@ class SamplingStates:
if
not
(
do_top_k
or
do_top_p
):
if
not
(
do_top_k
or
do_top_p
):
return
logits
return
logits
top_k
=
self
.
top_k
.
gpu
[
idx_mapping
]
if
do_top_k
else
None
top_k
=
self
.
top_k
.
gpu
[
expanded_
idx_mapping
]
if
do_top_k
else
None
top_p
=
self
.
top_p
.
gpu
[
idx_mapping
]
if
do_top_p
else
None
top_p
=
self
.
top_p
.
gpu
[
expanded_
idx_mapping
]
if
do_top_p
else
None
return
apply_top_k_top_p
(
logits
,
top_k
,
top_p
)
return
apply_top_k_top_p
(
logits
,
top_k
,
top_p
)
def
max_num_logprobs
(
self
,
idx_mapping_np
:
np
.
ndarray
)
->
int
:
def
max_num_logprobs
(
self
,
idx_mapping_np
:
np
.
ndarray
)
->
int
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment