Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bb78fb31
Unverified
Commit
bb78fb31
authored
Feb 21, 2025
by
Lu Fang
Committed by
GitHub
Feb 22, 2025
Browse files
[v1] Support allowed_token_ids in v1 Sampler (#13210)
Signed-off-by:
Lu Fang
<
lufang@fb.com
>
parent
8aca27fa
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
168 additions
and
19 deletions
+168
-19
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+1
-0
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+79
-15
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+13
-0
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+14
-0
vllm/v1/sample/metadata.py
vllm/v1/sample/metadata.py
+4
-0
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+16
-2
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+41
-2
No files found.
tests/v1/sample/test_rejection_sampler.py
View file @
bb78fb31
...
...
@@ -43,6 +43,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
output_token_ids
=
[],
min_tokens
=
{},
logit_bias
=
[
None
]
*
batch_size
,
allowed_token_ids_mask
=
None
,
)
...
...
tests/v1/sample/test_sampler.py
View file @
bb78fb31
...
...
@@ -57,6 +57,26 @@ def _create_logit_bias(
return
res
def
_create_allowed_token_ids
(
batch_size
:
int
,
vocab_size
:
int
,
num_allowed_token_ids
:
int
,
device
:
torch
.
device
,
)
->
Optional
[
torch
.
Tensor
]:
mask
:
Optional
[
torch
.
Tensor
]
=
None
for
i
in
range
(
batch_size
):
if
i
%
2
==
1
:
continue
if
mask
is
None
:
mask
=
torch
.
zeros
((
batch_size
,
vocab_size
),
dtype
=
torch
.
bool
,
device
=
device
)
start
=
min
(
i
,
vocab_size
-
1
)
end
=
min
(
i
+
num_allowed_token_ids
,
vocab_size
-
1
)
mask
[
i
,
start
:
end
]
=
True
return
mask
def
_create_default_sampling_metadata
(
num_output_tokens
:
int
,
batch_size
:
int
,
...
...
@@ -92,6 +112,7 @@ def _create_default_sampling_metadata(
no_penalties
=
True
,
min_tokens
=
{},
logit_bias
=
[
None
]
*
batch_size
,
allowed_token_ids_mask
=
None
,
)
return
fake_sampling_metadata
...
...
@@ -253,7 +274,10 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
sampling_metadata
.
frequency_penalties
=
_create_penalty_tensor
(
batch_size
,
frequency_penalty
,
torch
.
device
(
device
))
output_token_ids
,
sorted_token_ids_in_output
=
\
_create_weighted_output_token_list
(
batch_size
,
VOCAB_SIZE
)
_create_weighted_output_token_list
(
batch_size
,
VOCAB_SIZE
,
)
sampling_metadata
.
output_token_ids
=
output_token_ids
sampling_metadata
.
no_penalties
=
False
sampler
=
Sampler
()
...
...
@@ -262,8 +286,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
for
batch_idx
in
range
(
batch_size
):
non_penalized_token_id
=
logits
[
batch_idx
].
argmax
().
item
()
penalized_token_id
=
logits
[
batch_idx
].
argmin
().
item
()
distinct_sorted_token_ids_in_output
=
\
sorted_token_ids_in_output
[
batch_idx
]
distinct_sorted_token_ids_in_output
=
sorted_token_ids_in_output
[
batch_idx
]
most_frequent_token_id
=
distinct_sorted_token_ids_in_output
[
len
(
distinct_sorted_token_ids_in_output
)
-
1
]
if
frequency_penalty
>
0
:
...
...
@@ -272,8 +296,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
# non-penalized token ID is not present in the output, while the
# most penalized token is the one that occurs most frequently in
# the output.
assert
non_penalized_token_id
\
not
in
distinct_sorted_token_ids_in_output
assert
(
non_penalized_token_id
not
in
distinct_sorted_token_ids_in_output
)
assert
penalized_token_id
==
most_frequent_token_id
elif
frequency_penalty
<
0
:
# If `frequency_penalty` is set to < 0, it indicates
...
...
@@ -282,8 +306,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
# in the output, while the penalized token ID is one that has not
# yet appeared.
assert
non_penalized_token_id
==
most_frequent_token_id
assert
penalized_token_id
\
not
in
distinct_sorted_token_ids_in_output
assert
penalized_token_id
not
in
distinct_sorted_token_ids_in_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
...
...
@@ -318,18 +341,18 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
# If `repetition_penalty` > 1.0, verify that the non-penalized
# token ID has not been seen before, while the penalized token ID
# exists either in the prompt or the output.
assert
(
non_penalized_token_id
not
in
prompt_tokens
and
\
non_penalized_token_id
not
in
output_tokens
)
assert
(
penalized_token_id
in
prompt_tokens
or
\
penalized_token_id
in
output_tokens
)
assert
(
non_penalized_token_id
not
in
prompt_tokens
and
non_penalized_token_id
not
in
output_tokens
)
assert
(
penalized_token_id
in
prompt_tokens
or
penalized_token_id
in
output_tokens
)
elif
repetition_penalty
<
1.0
:
# If `repetition_penalty` < 1.0, verify that the penalized
# token ID has not been seen before, while the non-penalized
# token ID exists either in the prompt or the output.
assert
(
penalized_token_id
not
in
prompt_tokens
and
\
penalized_token_id
not
in
output_tokens
)
assert
(
non_penalized_token_id
in
prompt_tokens
or
\
non_penalized_token_id
in
output_tokens
)
assert
(
penalized_token_id
not
in
prompt_tokens
and
penalized_token_id
not
in
output_tokens
)
assert
(
non_penalized_token_id
in
prompt_tokens
or
non_penalized_token_id
in
output_tokens
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
...
...
@@ -404,3 +427,44 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
1e-2
)
else
:
assert
logits_for_req
[
token_id
]
==
pytest
.
approx
(
1e-2
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
def test_sampler_allowed_token_ids(device: str, batch_size: int,
                                   num_allowed_token_ids: int):
    """
    Test to verify that `Sampler.apply_allowed_token_ids` sets exactly the
    logits selected by `allowed_token_ids_mask` to -inf and does not set
    any other logit to -inf.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    mask = _create_allowed_token_ids(
        batch_size=batch_size,
        vocab_size=VOCAB_SIZE,
        num_allowed_token_ids=num_allowed_token_ids,
        device=torch.device(device),
    )
    sampling_metadata.allowed_token_ids_mask = mask
    sampler = Sampler()
    logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        logits_for_req = logits[batch_idx]
        if batch_idx % 2 == 1:
            # Odd rows have an all-False mask, so nothing may be -inf.
            assert torch.all(logits_for_req != -float("inf"))
            continue
        # The masked range depends only on the row, so compute it once
        # per request instead of once per token.
        start = min(batch_idx, VOCAB_SIZE - 1)
        end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
        for token_id in range(VOCAB_SIZE):
            if start <= token_id < end:
                assert logits_for_req[token_id] == -float(
                    "inf"), f"{batch_idx}, {token_id}"
            else:
                assert logits_for_req[token_id] != -float("inf")
tests/v1/worker/test_gpu_input_batch.py
View file @
bb78fb31
...
...
@@ -66,6 +66,10 @@ def _construct_expected_sampling_metadata(
temperature
=
[
0.0
for
_
in
range
(
num_reqs
)]
min_tokens
=
{}
logit_bias
=
[
None
]
*
num_reqs
allowed_token_ids_mask
=
torch
.
zeros
(
num_reqs
,
VOCAB_SIZE
,
dtype
=
torch
.
bool
,
device
=
device
)
for
req
in
reqs
:
if
req
.
req_id
not
in
req_ids_retained
:
continue
...
...
@@ -86,6 +90,10 @@ def _construct_expected_sampling_metadata(
req
.
sampling_params
.
min_tokens
,
req
.
sampling_params
.
all_stop_token_ids
)
logit_bias
[
index_in_input_batch
]
=
req
.
sampling_params
.
logit_bias
if
req
.
sampling_params
.
allowed_token_ids
:
allowed_token_ids_mask
[
index_in_input_batch
][
req
.
sampling_params
.
allowed_token_ids
]
=
True
return
SamplingMetadata
(
temperature
=
torch
.
tensor
(
temperature
,
dtype
=
torch
.
float
,
device
=
device
),
...
...
@@ -121,6 +129,7 @@ def _construct_expected_sampling_metadata(
and
all
(
x
==
0
for
x
in
frequency_penalties
)
and
all
(
x
==
1
for
x
in
repetition_penalties
)),
logit_bias
=
logit_bias
,
allowed_token_ids_mask
=
allowed_token_ids_mask
,
)
...
...
@@ -242,3 +251,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
assert
expected_sampling_metadata
.
no_penalties
==
\
sampling_metadata
.
no_penalties
assert
expected_sampling_metadata
.
logit_bias
==
sampling_metadata
.
logit_bias
if
sampling_metadata
.
allowed_token_ids_mask
:
assert
torch
.
allclose
(
expected_sampling_metadata
.
allowed_token_ids_mask
,
sampling_metadata
.
allowed_token_ids_mask
)
vllm/v1/engine/processor.py
View file @
bb78fb31
...
...
@@ -83,6 +83,19 @@ class Processor:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
def _validate_allowed_token_ids(
    self,
    params: Union[SamplingParams, PoolingParams],
) -> None:
    """Reject requests whose `allowed_token_ids` fall outside the vocab.

    Non-sampling params and sampling params without `allowed_token_ids`
    are accepted without further checks.

    Raises:
        ValueError: if any id is negative or not smaller than
            `self.model_config.vocab_size`.
    """
    if not isinstance(params, SamplingParams):
        return

    allowed = params.allowed_token_ids
    if allowed is None:
        return

    for token_id in allowed:
        if not (0 <= token_id < self.model_config.vocab_size):
            raise ValueError(
                "allowed_token_ids contains out-of-vocab token id")
def
process_inputs
(
self
,
request_id
:
str
,
...
...
@@ -100,6 +113,7 @@ class Processor:
self
.
_validate_logprobs
(
params
)
self
.
_validate_lora
(
lora_request
)
self
.
_validate_allowed_token_ids
(
params
)
if
arrival_time
is
None
:
arrival_time
=
time
.
time
()
...
...
vllm/v1/sample/metadata.py
View file @
bb78fb31
...
...
@@ -37,3 +37,7 @@ class SamplingMetadata:
min_tokens
:
Dict
[
int
,
Tuple
[
int
,
Set
[
int
]]]
logit_bias
:
List
[
Optional
[
Dict
[
int
,
float
]]]
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size).
allowed_token_ids_mask
:
Optional
[
torch
.
Tensor
]
vllm/v1/sample/sampler.py
View file @
bb78fb31
...
...
@@ -47,6 +47,8 @@ class Sampler(nn.Module):
# Use float32 for the logits.
logits
=
logits
.
to
(
torch
.
float32
)
# Apply allowed token ids.
logits
=
self
.
apply_allowed_token_ids
(
logits
,
sampling_metadata
)
# Apply logits bias.
logits
=
self
.
apply_logits_bias
(
logits
,
sampling_metadata
)
# Apply penalties (e.g., min_tokens, freq_penalties).
...
...
@@ -184,11 +186,13 @@ class Sampler(nn.Module):
if
not
sampling_metadata
.
no_penalties
:
assert
sampling_metadata
.
prompt_token_ids
is
not
None
logits
=
apply_all_penalties
(
logits
,
sampling_metadata
.
prompt_token_ids
,
logits
,
sampling_metadata
.
prompt_token_ids
,
sampling_metadata
.
presence_penalties
,
sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
output_token_ids
)
sampling_metadata
.
output_token_ids
,
)
return
logits
def
apply_min_p
(
...
...
@@ -226,3 +230,13 @@ class Sampler(nn.Module):
for
token_id
,
bias
in
logit_bias
.
items
():
logits
[
i
,
token_id
]
+=
bias
return
logits
def apply_allowed_token_ids(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Suppress logits selected by `allowed_token_ids_mask`.

    When the metadata carries a mask, every position where the mask is
    True is overwritten with -inf *in place*; the same `logits` tensor is
    returned. Without a mask the logits are returned untouched.
    """
    mask = sampling_metadata.allowed_token_ids_mask
    if mask is None:
        return logits

    # NOTE(review): True means "fill with -inf" here, i.e. the position is
    # suppressed. Confirm that whoever builds the mask sets True for the
    # tokens to block rather than for the allowed ones.
    logits.masked_fill_(mask, float("-inf"))
    return logits
vllm/v1/worker/gpu_input_batch.py
View file @
bb78fb31
...
...
@@ -143,7 +143,7 @@ class InputBatch:
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
frequency_penalties_cpu
=
\
self
.
frequency_penalties_cpu_tensor
.
numpy
()
self
.
frequency_penalties_cpu_tensor
.
numpy
()
self
.
frequency_penalties_reqs
:
Set
[
str
]
=
set
()
# Presence penalty related data structures
...
...
@@ -168,7 +168,7 @@ class InputBatch:
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
repetition_penalties_cpu
=
\
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_reqs
:
Set
[
str
]
=
set
()
# req_index -> (min_tokens, stop_token_ids)
...
...
@@ -192,6 +192,9 @@ class InputBatch:
self
.
logit_bias
:
List
[
Optional
[
Dict
[
int
,
float
]]]
=
[
None
]
*
max_num_reqs
self
.
has_allowed_token_ids
:
Set
[
str
]
=
set
()
self
.
allowed_token_ids_mask
:
Optional
[
torch
.
Tensor
]
=
None
self
.
allowed_token_ids_mask_cpu_tensor
:
Optional
[
torch
.
Tensor
]
=
None
self
.
req_output_token_ids
:
List
[
Optional
[
List
[
int
]]]
=
[]
...
...
@@ -287,6 +290,22 @@ class InputBatch:
if
sampling_params
.
logit_bias
is
not
None
:
self
.
logit_bias
[
req_index
]
=
sampling_params
.
logit_bias
if
sampling_params
.
allowed_token_ids
:
self
.
has_allowed_token_ids
.
add
(
req_id
)
if
self
.
allowed_token_ids_mask_cpu_tensor
is
None
:
# Lazy allocation for this tensor, which can be large.
self
.
allowed_token_ids_mask
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
bool
,
device
=
self
.
device
)
self
.
allowed_token_ids_mask_cpu_tensor
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
bool
,
device
=
"cpu"
)
self
.
allowed_token_ids_mask_cpu_tensor
[
req_index
][
sampling_params
.
allowed_token_ids
]
=
True
# Add request lora ID
if
request
.
lora_request
:
lora_id
=
request
.
lora_request
.
lora_int_id
...
...
@@ -332,6 +351,9 @@ class InputBatch:
self
.
request_lora_mapping
[
req_index
]
=
0
self
.
logit_bias
[
req_index
]
=
None
self
.
has_allowed_token_ids
.
discard
(
req_id
)
if
self
.
allowed_token_ids_mask_cpu_tensor
is
not
None
:
self
.
allowed_token_ids_mask_cpu_tensor
[
req_index
].
fill_
(
False
)
return
req_index
def
condense
(
self
,
empty_req_indices
:
List
[
int
])
->
None
:
...
...
@@ -400,6 +422,11 @@ class InputBatch:
self
.
logit_bias
[
empty_index
]
=
self
.
logit_bias
[
last_req_index
]
if
self
.
allowed_token_ids_mask_cpu_tensor
is
not
None
:
self
.
allowed_token_ids_mask_cpu_tensor
[
empty_index
]
=
self
.
allowed_token_ids_mask_cpu_tensor
[
last_req_index
]
# Decrement last_req_index since it is now empty.
last_req_index
-=
1
...
...
@@ -442,6 +469,13 @@ class InputBatch:
else
:
prompt_token_ids
=
None
allowed_token_ids_mask
:
Optional
[
torch
.
Tensor
]
=
None
if
not
self
.
no_allowed_token_ids
:
assert
self
.
allowed_token_ids_mask
is
not
None
copy_slice
(
self
.
allowed_token_ids_mask_cpu_tensor
,
self
.
allowed_token_ids_mask
,
num_reqs
)
allowed_token_ids_mask
=
self
.
allowed_token_ids_mask
[:
num_reqs
]
return
SamplingMetadata
(
temperature
=
temperature
,
all_greedy
=
self
.
all_greedy
,
...
...
@@ -460,6 +494,7 @@ class InputBatch:
min_tokens
=
self
.
min_tokens
,
no_penalties
=
self
.
no_penalties
,
logit_bias
=
self
.
logit_bias
[:
num_reqs
],
allowed_token_ids_mask
=
allowed_token_ids_mask
,
)
def
get_sampling_metadata
(
...
...
@@ -550,3 +585,7 @@ class InputBatch:
@property
def no_prompt_logprob(self) -> bool:
    """True when `self.num_prompt_logprobs` is falsy (e.g. empty)."""
    if self.num_prompt_logprobs:
        return False
    return True
@property
def no_allowed_token_ids(self) -> bool:
    """True when the `has_allowed_token_ids` set is empty."""
    return not self.has_allowed_token_ids
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment