Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6224a9f6
Unverified
Commit
6224a9f6
authored
Feb 14, 2025
by
Lu Fang
Committed by
GitHub
Feb 14, 2025
Browse files
Support logit_bias in v1 Sampler (#13079)
parent
085b7b2d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
200 additions
and
101 deletions
+200
-101
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+59
-12
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+80
-62
vllm/sampling_params.py
vllm/sampling_params.py
+3
-1
vllm/v1/sample/metadata.py
vllm/v1/sample/metadata.py
+2
-0
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+16
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+40
-26
No files found.
tests/v1/sample/test_sampler.py
View file @
6224a9f6
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
...
@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
)
)
def
_create_logit_bias
(
batch_size
:
int
,
vocab_size
:
int
,
bias_value
:
float
,
)
->
List
[
Optional
[
Dict
[
int
,
float
]]]:
res
:
List
[
Optional
[
Dict
[
int
,
float
]]]
=
[]
for
i
in
range
(
batch_size
):
logit_bias
=
{
min
(
i
,
vocab_size
-
1
):
bias_value
}
res
.
append
(
logit_bias
)
return
res
def
_create_default_sampling_metadata
(
def
_create_default_sampling_metadata
(
num_output_tokens
:
int
,
num_output_tokens
:
int
,
batch_size
:
int
,
batch_size
:
int
,
...
@@ -80,6 +92,7 @@ def _create_default_sampling_metadata(
...
@@ -80,6 +92,7 @@ def _create_default_sampling_metadata(
no_penalties
=
True
,
no_penalties
=
True
,
min_tokens
=
[],
min_tokens
=
[],
stop_token_ids
=
[],
stop_token_ids
=
[],
logit_bias
=
[
None
]
*
batch_size
,
)
)
return
fake_sampling_metadata
return
fake_sampling_metadata
...
@@ -89,14 +102,14 @@ def _generate_min_token_penalties_and_stop_tokens(
...
@@ -89,14 +102,14 @@ def _generate_min_token_penalties_and_stop_tokens(
batch_indices_for_min_token_penalty
:
List
[
int
]
batch_indices_for_min_token_penalty
:
List
[
int
]
)
->
Tuple
[
List
[
int
],
List
[
Set
[
int
]]]:
)
->
Tuple
[
List
[
int
],
List
[
Set
[
int
]]]:
"""
"""
Generates and returns a list of minimum token penalties (`min_tokens`)
Generates and returns a list of minimum token penalties (`min_tokens`)
and a corresponding list of stop token IDs (`stop_token_ids`) for each
and a corresponding list of stop token IDs (`stop_token_ids`) for each
batch.
batch.
If a batch index is included in `batch_indices_for_min_token_penalty`,
If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
"""
stop_token_ids
:
List
[
Set
[
int
]]
=
[]
stop_token_ids
:
List
[
Set
[
int
]]
=
[]
min_tokens
:
List
[
int
]
=
[]
min_tokens
:
List
[
int
]
=
[]
...
@@ -120,7 +133,7 @@ def _create_weighted_output_token_list(
...
@@ -120,7 +133,7 @@ def _create_weighted_output_token_list(
batch_size
:
int
,
batch_size
:
int
,
vocab_size
:
int
)
->
Tuple
[
List
[
List
[
int
]],
List
[
List
[
int
]]]:
vocab_size
:
int
)
->
Tuple
[
List
[
List
[
int
]],
List
[
List
[
int
]]]:
"""
"""
Creates an output token list where each token occurs a distinct
Creates an output token list where each token occurs a distinct
number of times.
number of times.
For each batch, a random subset of token IDs is selected from the
For each batch, a random subset of token IDs is selected from the
...
@@ -129,8 +142,8 @@ def _create_weighted_output_token_list(
...
@@ -129,8 +142,8 @@ def _create_weighted_output_token_list(
Returns:
Returns:
Tuple[List[List[int]], List[List[int]]]:
Tuple[List[List[int]], List[List[int]]]:
- The first element is the output token list, where each sublist
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
corresponds to a batch and contains tokens with weighted
frequencies.
frequencies.
- The second element is a list of distinct token IDs for each
- The second element is a list of distinct token IDs for each
batch, ordered by their frequency in the corresponding output
batch, ordered by their frequency in the corresponding output
...
@@ -155,7 +168,7 @@ def _create_weighted_output_token_list(
...
@@ -155,7 +168,7 @@ def _create_weighted_output_token_list(
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
def
test_sampler_min_tokens_penalty
(
device
:
str
,
batch_size
:
int
):
def
test_sampler_min_tokens_penalty
(
device
:
str
,
batch_size
:
int
):
"""
"""
Tests that if the number of output tokens is less than
Tests that if the number of output tokens is less than
SamplingParams.min_tokens then we will set the logits for
SamplingParams.min_tokens then we will set the logits for
the stop token ids to -inf.
the stop token ids to -inf.
"""
"""
...
@@ -283,7 +296,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
...
@@ -283,7 +296,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
def
test_sampler_repetition_penalty
(
device
:
str
,
batch_size
:
int
,
def
test_sampler_repetition_penalty
(
device
:
str
,
batch_size
:
int
,
repetition_penalty
:
float
):
repetition_penalty
:
float
):
"""
"""
Test to verify that when the repetition penalty is enabled, tokens
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
are penalized based on their presence in the prompt or the existing
output.
output.
"""
"""
...
@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
...
@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
penalized_token_id
not
in
output_tokens
)
penalized_token_id
not
in
output_tokens
)
assert
(
non_penalized_token_id
in
prompt_tokens
or
\
assert
(
non_penalized_token_id
in
prompt_tokens
or
\
non_penalized_token_id
in
output_tokens
)
non_penalized_token_id
in
output_tokens
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"bias_value"
,
[
-
0.1
,
1.2
])
def
test_sampler_logit_bias
(
device
:
str
,
batch_size
:
int
,
bias_value
:
float
):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
torch
.
set_default_device
(
device
)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits
=
_create_fake_logits
(
batch_size
,
VOCAB_SIZE
)
sampling_metadata
=
_create_default_sampling_metadata
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
sampling_metadata
.
logit_bias
=
_create_logit_bias
(
batch_size
=
batch_size
,
vocab_size
=
VOCAB_SIZE
,
bias_value
=
bias_value
,
)
sampler
=
Sampler
()
logits
=
sampler
.
apply_logits_bias
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
logits_for_req
=
logits
[
batch_idx
]
biased_index
=
min
(
batch_idx
,
VOCAB_SIZE
-
1
)
for
token_id
in
range
(
VOCAB_SIZE
):
if
biased_index
==
token_id
:
assert
logits_for_req
[
token_id
]
==
pytest
.
approx
(
bias_value
+
1e-2
)
else
:
assert
logits_for_req
[
token_id
]
==
pytest
.
approx
(
1e-2
)
tests/v1/worker/test_gpu_input_batch.py
View file @
6224a9f6
...
@@ -45,9 +45,11 @@ def _remove_requests(
...
@@ -45,9 +45,11 @@ def _remove_requests(
def
_construct_expected_sampling_metadata
(
def
_construct_expected_sampling_metadata
(
reqs
:
List
[
CachedRequestState
],
req_ids_retained
:
Set
[
int
],
reqs
:
List
[
CachedRequestState
],
req_id_index_in_input_batch
:
Dict
[
str
,
int
],
req_ids_retained
:
Set
[
int
],
device
:
torch
.
device
)
->
SamplingMetadata
:
req_id_index_in_input_batch
:
Dict
[
str
,
int
],
device
:
torch
.
device
,
)
->
SamplingMetadata
:
"""
"""
Constructs and returns the expected SamplingMetadata for this
Constructs and returns the expected SamplingMetadata for this
batch.
batch.
...
@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata(
...
@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata(
temperature
=
[
0.0
for
_
in
range
(
num_reqs
)]
temperature
=
[
0.0
for
_
in
range
(
num_reqs
)]
stop_token_ids
:
List
[
Set
[
int
]]
=
[
set
()
for
_
in
range
(
num_reqs
)]
stop_token_ids
:
List
[
Set
[
int
]]
=
[
set
()
for
_
in
range
(
num_reqs
)]
min_tokens
=
[
0
for
_
in
range
(
num_reqs
)]
min_tokens
=
[
0
for
_
in
range
(
num_reqs
)]
logit_bias
=
[
None
]
*
num_reqs
for
req
in
reqs
:
for
req
in
reqs
:
if
req
.
req_id
not
in
req_ids_retained
:
if
req
.
req_id
not
in
req_ids_retained
:
continue
continue
...
@@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata(
...
@@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata(
prompt_token_ids
[
index_in_input_batch
]
=
req
.
prompt_token_ids
prompt_token_ids
[
index_in_input_batch
]
=
req
.
prompt_token_ids
presence_penalties
[
presence_penalties
[
index_in_input_batch
]
=
req
.
sampling_params
.
presence_penalty
index_in_input_batch
]
=
req
.
sampling_params
.
presence_penalty
frequency_penalties
[
frequency_penalties
[
index_in_input_batch
]
=
(
index_in_input_batch
]
=
req
.
sampling_params
.
frequency_penalty
req
.
sampling_params
.
frequency_penalty
)
repetition_penalties
[
repetition_penalties
[
index_in_input_batch
]
=
(
index_in_input_batch
]
=
req
.
sampling_params
.
repetition_penalty
req
.
sampling_params
.
repetition_penalty
)
top_k
[
index_in_input_batch
]
=
req
.
sampling_params
.
top_k
top_k
[
index_in_input_batch
]
=
req
.
sampling_params
.
top_k
top_p
[
index_in_input_batch
]
=
req
.
sampling_params
.
top_p
top_p
[
index_in_input_batch
]
=
req
.
sampling_params
.
top_p
temperature
[
index_in_input_batch
]
=
req
.
sampling_params
.
temperature
temperature
[
index_in_input_batch
]
=
req
.
sampling_params
.
temperature
stop_token_ids
[
stop_token_ids
[
index_in_input_batch
]
=
req
.
sampling_params
.
all_stop_token_ids
index_in_input_batch
]
=
req
.
sampling_params
.
all_stop_token_ids
min_tokens
[
index_in_input_batch
]
=
req
.
sampling_params
.
min_tokens
min_tokens
[
index_in_input_batch
]
=
req
.
sampling_params
.
min_tokens
logit_bias
[
index_in_input_batch
]
=
req
.
sampling_params
.
logit_bias
return
SamplingMetadata
(
return
SamplingMetadata
(
temperature
=
torch
.
tensor
(
temperature
,
dtype
=
torch
.
float
,
device
=
device
),
temperature
=
torch
.
tensor
(
temperature
,
dtype
=
torch
.
float
,
device
=
device
),
all_greedy
=
False
,
all_greedy
=
False
,
all_random
=
True
,
all_random
=
True
,
top_p
=
torch
.
tensor
(
top_p
,
dtype
=
torch
.
float
,
device
=
device
),
top_p
=
torch
.
tensor
(
top_p
,
dtype
=
torch
.
float
,
device
=
device
),
...
@@ -93,41 +97,45 @@ def _construct_expected_sampling_metadata(
...
@@ -93,41 +97,45 @@ def _construct_expected_sampling_metadata(
no_top_k
=
all
(
x
==
0
for
x
in
top_k
),
no_top_k
=
all
(
x
==
0
for
x
in
top_k
),
generators
=
{},
generators
=
{},
max_num_logprobs
=
0
,
max_num_logprobs
=
0
,
prompt_token_ids
=
make_tensor_with_pad
(
prompt_token_ids
=
make_tensor_with_pad
(
prompt_token_ids
,
prompt_token_ids
,
pad
=
VOCAB_SIZE
,
pad
=
VOCAB_SIZE
,
device
=
torch
.
device
(
device
),
device
=
torch
.
device
(
device
),
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
),
),
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
,
frequency_penalties
,
dtype
=
torch
.
float
,
dtype
=
torch
.
float
,
device
=
device
),
device
=
device
),
presence_penalties
=
torch
.
tensor
(
presence_penalties
=
torch
.
tensor
(
presence_penalties
,
presence_penalties
,
dtype
=
torch
.
float
,
dtype
=
torch
.
float
,
device
=
device
),
device
=
device
),
repetition_penalties
=
torch
.
tensor
(
repetition_penalties
=
torch
.
tensor
(
repetition_penalties
,
repetition_penalties
,
dtype
=
torch
.
float
,
dtype
=
torch
.
float
,
device
=
device
),
device
=
device
),
output_token_ids
=
output_token_ids
,
output_token_ids
=
output_token_ids
,
min_tokens
=
min_tokens
,
min_tokens
=
min_tokens
,
stop_token_ids
=
stop_token_ids
,
stop_token_ids
=
stop_token_ids
,
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
and
\
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
all
(
x
==
0
for
x
in
frequency_penalties
)
and
\
and
all
(
x
==
0
for
x
in
frequency_penalties
)
all
(
x
==
1
for
x
in
repetition_penalties
))
and
all
(
x
==
1
for
x
in
repetition_penalties
)),
logit_bias
=
logit_bias
,
)
)
def
_create_sampling_params
():
def
_create_sampling_params
():
return
SamplingParams
(
top_k
=
np
.
random
.
randint
(
1
,
10
),
return
SamplingParams
(
top_p
=
np
.
random
.
uniform
(
0.0
,
1.0
),
top_k
=
np
.
random
.
randint
(
1
,
10
),
presence_penalty
=
np
.
random
.
uniform
(
-
2.0
,
2.0
),
top_p
=
np
.
random
.
uniform
(
0.0
,
1.0
),
repetition_penalty
=
np
.
random
.
uniform
(
0.0
,
2.0
),
presence_penalty
=
np
.
random
.
uniform
(
-
2.0
,
2.0
),
frequency_penalty
=
np
.
random
.
uniform
(
-
2.0
,
2.0
),
repetition_penalty
=
np
.
random
.
uniform
(
0.0
,
2.0
),
min_tokens
=
np
.
random
.
randint
(
1
,
10
),
frequency_penalty
=
np
.
random
.
uniform
(
-
2.0
,
2.0
),
stop_token_ids
=
[
min_tokens
=
np
.
random
.
randint
(
1
,
10
),
np
.
random
.
randint
(
0
,
VOCAB_SIZE
)
stop_token_ids
=
[
for
_
in
range
(
np
.
random
.
randint
(
10
))
np
.
random
.
randint
(
0
,
VOCAB_SIZE
)
])
for
_
in
range
(
np
.
random
.
randint
(
10
))
],
logit_bias
=
{
0
:
np
.
random
.
uniform
(
-
3.0
,
3.0
)},
)
def
_construct_cached_request_state
(
req_id_suffix
:
int
):
def
_construct_cached_request_state
(
req_id_suffix
:
int
):
...
@@ -139,16 +147,18 @@ def _construct_cached_request_state(req_id_suffix: int):
...
@@ -139,16 +147,18 @@ def _construct_cached_request_state(req_id_suffix: int):
np
.
random
.
randint
(
0
,
VOCAB_SIZE
)
np
.
random
.
randint
(
0
,
VOCAB_SIZE
)
for
_
in
range
(
np
.
random
.
randint
(
0
,
NUM_OUTPUT_TOKENS
))
for
_
in
range
(
np
.
random
.
randint
(
0
,
NUM_OUTPUT_TOKENS
))
]
]
return
CachedRequestState
(
req_id
=
f
"req_id_
{
req_id_suffix
}
"
,
return
CachedRequestState
(
prompt_token_ids
=
prompt_token_ids
,
req_id
=
f
"req_id_
{
req_id_suffix
}
"
,
prompt
=
None
,
prompt_token_ids
=
prompt_token_ids
,
sampling_params
=
_create_sampling_params
(),
prompt
=
None
,
mm_inputs
=
[],
sampling_params
=
_create_sampling_params
(),
mm_positions
=
[],
mm_inputs
=
[],
block_ids
=
[],
mm_positions
=
[],
generator
=
None
,
block_ids
=
[],
num_computed_tokens
=
len
(
output_token_ids
),
generator
=
None
,
output_token_ids
=
output_token_ids
)
num_computed_tokens
=
len
(
output_token_ids
),
output_token_ids
=
output_token_ids
,
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
...
@@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
...
@@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
output of `make_sampling_metadata` is then compared against the expected
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
results to ensure correctness.
"""
"""
input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
input_batch
:
InputBatch
=
InputBatch
(
max_model_len
=
1024
,
max_num_reqs
=
batch_size
,
max_num_blocks_per_req
=
10
,
max_model_len
=
1024
,
device
=
torch
.
device
(
device
),
max_num_blocks_per_req
=
10
,
pin_memory
=
is_pin_memory_available
(),
device
=
torch
.
device
(
device
),
vocab_size
=
1024
)
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
)
reqs
:
List
[
CachedRequestState
]
=
[]
reqs
:
List
[
CachedRequestState
]
=
[]
req_id_reqs
=
{}
req_id_reqs
=
{}
req_id_output_token_ids
=
{}
req_id_output_token_ids
=
{}
...
@@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
...
@@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
sampling_metadata
.
top_p
)
sampling_metadata
.
top_p
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
top_k
,
assert
torch
.
allclose
(
expected_sampling_metadata
.
top_k
,
sampling_metadata
.
top_k
)
sampling_metadata
.
top_k
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
frequency_penalties
,
assert
torch
.
allclose
(
sampling_metadata
.
frequency_penalties
)
expected_sampling_metadata
.
frequency_penalties
,
assert
torch
.
allclose
(
expected_sampling_metadata
.
presence_penalties
,
sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
presence_penalties
)
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
repetition_penalties
,
assert
torch
.
allclose
(
sampling_metadata
.
repetition_penalties
)
expected_sampling_metadata
.
presence_penalties
,
sampling_metadata
.
presence_penalties
,
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
repetition_penalties
,
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
prompt_token_ids
,
assert
torch
.
allclose
(
expected_sampling_metadata
.
prompt_token_ids
,
sampling_metadata
.
prompt_token_ids
)
sampling_metadata
.
prompt_token_ids
)
assert
(
expected_sampling_metadata
.
output_token_ids
==
assert
(
expected_sampling_metadata
.
output_token_ids
==
sampling_metadata
.
output_token_ids
)
sampling_metadata
.
output_token_ids
)
assert
(
assert
expected_sampling_metadata
.
min_tokens
==
sampling_metadata
.
min_tokens
expected_sampling_metadata
.
min
_tokens
==
sampling_metadata
.
min_tokens
)
assert
expected_sampling_metadata
.
stop
_token
_id
s
==
\
assert
(
expected_
sampling_metadata
.
stop_token_ids
==
sampling_metadata
.
stop_token_ids
sampling_metadata
.
stop_token_ids
)
assert
expected_
sampling_metadata
.
no_penalties
==
\
assert
(
expected_
sampling_metadata
.
no_penalties
==
sampling_metadata
.
no_penalties
sampling_metadata
.
no_
penalties
)
assert
expected_
sampling_metadata
.
no_
top_p
==
sampling_metadata
.
no_top_p
assert
(
expected_sampling_metadata
.
no_top_
p
==
sampling_metadata
.
no_top_
p
)
assert
expected_sampling_metadata
.
no_top_
k
==
sampling_metadata
.
no_top_
k
assert
(
expected_sampling_metadata
.
no_top_k
==
sampling_metadata
.
no_top_k
)
assert
expected_sampling_metadata
.
logit_bias
==
sampling_metadata
.
logit_bias
vllm/sampling_params.py
View file @
6224a9f6
...
@@ -243,8 +243,10 @@ class SamplingParams(
...
@@ -243,8 +243,10 @@ class SamplingParams(
allowed_token_ids
:
Optional
[
List
[
int
]]
=
None
,
allowed_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
"SamplingParams"
:
)
->
"SamplingParams"
:
if
logit_bias
is
not
None
:
if
logit_bias
is
not
None
:
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias
=
{
logit_bias
=
{
int
(
token
):
bias
int
(
token
):
min
(
100.0
,
max
(
-
100.0
,
bias
))
for
token
,
bias
in
logit_bias
.
items
()
for
token
,
bias
in
logit_bias
.
items
()
}
}
...
...
vllm/v1/sample/metadata.py
View file @
6224a9f6
...
@@ -32,3 +32,5 @@ class SamplingMetadata:
...
@@ -32,3 +32,5 @@ class SamplingMetadata:
output_token_ids
:
List
[
List
[
int
]]
output_token_ids
:
List
[
List
[
int
]]
min_tokens
:
List
[
int
]
min_tokens
:
List
[
int
]
stop_token_ids
:
List
[
Set
[
int
]]
stop_token_ids
:
List
[
Set
[
int
]]
logit_bias
:
List
[
Optional
[
Dict
[
int
,
float
]]]
vllm/v1/sample/sampler.py
View file @
6224a9f6
...
@@ -37,6 +37,8 @@ class Sampler(nn.Module):
...
@@ -37,6 +37,8 @@ class Sampler(nn.Module):
# Use float32 for the logits.
# Use float32 for the logits.
logits
=
logits
.
to
(
torch
.
float32
)
logits
=
logits
.
to
(
torch
.
float32
)
# Apply logits bias.
logits
=
self
.
apply_logits_bias
(
logits
,
sampling_metadata
)
# Apply penalties (e.g., min_tokens, freq_penalties).
# Apply penalties (e.g., min_tokens, freq_penalties).
logits
=
self
.
apply_penalties
(
logits
,
sampling_metadata
)
logits
=
self
.
apply_penalties
(
logits
,
sampling_metadata
)
# Apply temperature.
# Apply temperature.
...
@@ -166,3 +168,17 @@ class Sampler(nn.Module):
...
@@ -166,3 +168,17 @@ class Sampler(nn.Module):
sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
output_token_ids
)
sampling_metadata
.
output_token_ids
)
return
logits
return
logits
def
apply_logits_bias
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
# TODO(houseroad): this implementation is extremely inefficient.
# One idea is implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
for
i
,
logit_bias
in
enumerate
(
sampling_metadata
.
logit_bias
):
if
logit_bias
:
for
token_id
,
bias
in
logit_bias
.
items
():
logits
[
i
,
token_id
]
+=
bias
return
logits
vllm/v1/worker/gpu_input_batch.py
View file @
6224a9f6
...
@@ -130,7 +130,7 @@ class InputBatch:
...
@@ -130,7 +130,7 @@ class InputBatch:
device
=
"cpu"
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
)
self
.
frequency_penalties_cpu
=
\
self
.
frequency_penalties_cpu
=
\
self
.
frequency_penalties_cpu_tensor
.
numpy
()
self
.
frequency_penalties_cpu_tensor
.
numpy
()
self
.
frequency_penalties_reqs
:
Set
[
str
]
=
set
()
self
.
frequency_penalties_reqs
:
Set
[
str
]
=
set
()
# Presence penalty related data structures
# Presence penalty related data structures
...
@@ -141,8 +141,8 @@ class InputBatch:
...
@@ -141,8 +141,8 @@ class InputBatch:
dtype
=
torch
.
float
,
dtype
=
torch
.
float
,
device
=
"cpu"
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
)
self
.
presence_penalties_cpu
=
\
self
.
presence_penalties_cpu
=
self
.
presence_penalties_cpu_tensor
.
numpy
(
self
.
presence_penalties_cpu_tensor
.
numpy
(
)
)
self
.
presence_penalties_reqs
:
Set
[
str
]
=
set
()
self
.
presence_penalties_reqs
:
Set
[
str
]
=
set
()
# Repetition penalty related data structures
# Repetition penalty related data structures
...
@@ -155,7 +155,7 @@ class InputBatch:
...
@@ -155,7 +155,7 @@ class InputBatch:
device
=
"cpu"
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
)
self
.
repetition_penalties_cpu
=
\
self
.
repetition_penalties_cpu
=
\
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_reqs
:
Set
[
str
]
=
set
()
self
.
repetition_penalties_reqs
:
Set
[
str
]
=
set
()
self
.
min_tokens
:
List
[
int
]
=
[
0
]
*
max_num_reqs
self
.
min_tokens
:
List
[
int
]
=
[
0
]
*
max_num_reqs
...
@@ -180,6 +180,9 @@ class InputBatch:
...
@@ -180,6 +180,9 @@ class InputBatch:
# that are currently in the prefill phase.
# that are currently in the prefill phase.
self
.
num_prompt_logprobs
:
Dict
[
str
,
int
]
=
{}
self
.
num_prompt_logprobs
:
Dict
[
str
,
int
]
=
{}
self
.
logit_bias
:
List
[
Optional
[
Dict
[
int
,
float
]]]
=
[
None
]
*
max_num_reqs
def
add_request
(
def
add_request
(
self
,
self
,
request
:
"CachedRequestState"
,
request
:
"CachedRequestState"
,
...
@@ -220,16 +223,16 @@ class InputBatch:
...
@@ -220,16 +223,16 @@ class InputBatch:
self
.
top_k_cpu
[
req_index
]
=
sampling_params
.
top_k
self
.
top_k_cpu
[
req_index
]
=
sampling_params
.
top_k
if
sampling_params
.
top_k
>
0
:
if
sampling_params
.
top_k
>
0
:
self
.
top_k_reqs
.
add
(
req_id
)
self
.
top_k_reqs
.
add
(
req_id
)
self
.
frequency_penalties_cpu
[
req_index
]
=
\
self
.
frequency_penalties_cpu
[
sampling_params
.
frequency_penalty
req_index
]
=
sampling_params
.
frequency_penalty
if
sampling_params
.
frequency_penalty
!=
0.0
:
if
sampling_params
.
frequency_penalty
!=
0.0
:
self
.
frequency_penalties_reqs
.
add
(
req_id
)
self
.
frequency_penalties_reqs
.
add
(
req_id
)
self
.
presence_penalties_cpu
[
req_index
]
=
\
self
.
presence_penalties_cpu
[
sampling_params
.
presence_penalty
req_index
]
=
sampling_params
.
presence_penalty
if
sampling_params
.
presence_penalty
!=
0.0
:
if
sampling_params
.
presence_penalty
!=
0.0
:
self
.
presence_penalties_reqs
.
add
(
req_id
)
self
.
presence_penalties_reqs
.
add
(
req_id
)
self
.
repetition_penalties_cpu
[
req_index
]
=
\
self
.
repetition_penalties_cpu
[
sampling_params
.
repetition_penalty
req_index
]
=
sampling_params
.
repetition_penalty
if
sampling_params
.
repetition_penalty
!=
1.0
:
if
sampling_params
.
repetition_penalty
!=
1.0
:
self
.
repetition_penalties_reqs
.
add
(
req_id
)
self
.
repetition_penalties_reqs
.
add
(
req_id
)
self
.
min_tokens
[
req_index
]
=
sampling_params
.
min_tokens
self
.
min_tokens
[
req_index
]
=
sampling_params
.
min_tokens
...
@@ -244,6 +247,8 @@ class InputBatch:
...
@@ -244,6 +247,8 @@ class InputBatch:
self
.
num_logprobs
[
req_id
]
=
sampling_params
.
logprobs
self
.
num_logprobs
[
req_id
]
=
sampling_params
.
logprobs
if
sampling_params
.
prompt_logprobs
is
not
None
:
if
sampling_params
.
prompt_logprobs
is
not
None
:
self
.
num_prompt_logprobs
[
req_id
]
=
sampling_params
.
prompt_logprobs
self
.
num_prompt_logprobs
[
req_id
]
=
sampling_params
.
prompt_logprobs
if
sampling_params
.
logit_bias
is
not
None
:
self
.
logit_bias
[
req_index
]
=
sampling_params
.
logit_bias
# Add request lora ID
# Add request lora ID
if
request
.
lora_request
:
if
request
.
lora_request
:
...
@@ -284,6 +289,7 @@ class InputBatch:
...
@@ -284,6 +289,7 @@ class InputBatch:
self
.
lora_id_to_lora_request
.
pop
(
lora_id
)
self
.
lora_id_to_lora_request
.
pop
(
lora_id
)
self
.
request_lora_mapping
[
req_index
]
=
0
self
.
request_lora_mapping
[
req_index
]
=
0
self
.
logit_bias
[
req_index
]
=
None
return
req_index
return
req_index
def
clear
(
self
)
->
None
:
def
clear
(
self
)
->
None
:
...
@@ -302,6 +308,7 @@ class InputBatch:
...
@@ -302,6 +308,7 @@ class InputBatch:
self
.
request_lora_mapping
.
fill
(
0
)
self
.
request_lora_mapping
.
fill
(
0
)
self
.
lora_id_to_lora_request
.
clear
()
self
.
lora_id_to_lora_request
.
clear
()
self
.
lora_id_to_request_ids
.
clear
()
self
.
lora_id_to_request_ids
.
clear
()
self
.
logit_bias
=
[
None
]
*
self
.
max_num_reqs
def
condense
(
self
,
empty_req_indices
:
List
[
int
])
->
None
:
def
condense
(
self
,
empty_req_indices
:
List
[
int
])
->
None
:
if
self
.
num_reqs
==
0
:
if
self
.
num_reqs
==
0
:
...
@@ -332,8 +339,8 @@ class InputBatch:
...
@@ -332,8 +339,8 @@ class InputBatch:
self
.
token_ids_cpu
[
empty_index
,
:
num_tokens
]
=
self
.
token_ids_cpu
[
self
.
token_ids_cpu
[
empty_index
,
:
num_tokens
]
=
self
.
token_ids_cpu
[
last_req_index
,
:
num_tokens
]
last_req_index
,
:
num_tokens
]
self
.
num_tokens
[
empty_index
]
=
num_tokens
self
.
num_tokens
[
empty_index
]
=
num_tokens
self
.
num_prompt_tokens
[
empty_index
]
=
\
self
.
num_prompt_tokens
[
empty_index
]
=
self
.
num_prompt_tokens
[
self
.
num_prompt_tokens
[
last_req_index
]
last_req_index
]
self
.
num_computed_tokens_cpu
[
self
.
num_computed_tokens_cpu
[
empty_index
]
=
self
.
num_computed_tokens_cpu
[
last_req_index
]
empty_index
]
=
self
.
num_computed_tokens_cpu
[
last_req_index
]
self
.
block_table
.
move_row
(
last_req_index
,
empty_index
)
self
.
block_table
.
move_row
(
last_req_index
,
empty_index
)
...
@@ -341,15 +348,15 @@ class InputBatch:
...
@@ -341,15 +348,15 @@ class InputBatch:
last_req_index
]
last_req_index
]
self
.
top_p_cpu
[
empty_index
]
=
self
.
top_p_cpu
[
last_req_index
]
self
.
top_p_cpu
[
empty_index
]
=
self
.
top_p_cpu
[
last_req_index
]
self
.
top_k_cpu
[
empty_index
]
=
self
.
top_k_cpu
[
last_req_index
]
self
.
top_k_cpu
[
empty_index
]
=
self
.
top_k_cpu
[
last_req_index
]
self
.
frequency_penalties_cpu
[
empty_index
]
=
\
self
.
frequency_penalties_cpu
[
self
.
frequency_penalties_cpu
[
last_req_index
]
empty_index
]
=
self
.
frequency_penalties_cpu
[
last_req_index
]
self
.
presence_penalties_cpu
[
empty_index
]
=
\
self
.
presence_penalties_cpu
[
self
.
presence_penalties_cpu
[
last_req_index
]
empty_index
]
=
self
.
presence_penalties_cpu
[
last_req_index
]
self
.
repetition_penalties_cpu
[
empty_index
]
=
\
self
.
repetition_penalties_cpu
[
self
.
repetition_penalties_cpu
[
last_req_index
]
empty_index
]
=
self
.
repetition_penalties_cpu
[
last_req_index
]
self
.
min_tokens
[
empty_index
]
=
self
.
min_tokens
[
last_req_index
]
self
.
min_tokens
[
empty_index
]
=
self
.
min_tokens
[
last_req_index
]
self
.
stop_token_ids
[
empty_index
]
=
\
self
.
stop_token_ids
[
empty_index
]
=
self
.
stop_token_ids
[
self
.
stop_token_ids
[
last_req_index
]
last_req_index
]
generator
=
self
.
generators
.
pop
(
last_req_index
,
None
)
generator
=
self
.
generators
.
pop
(
last_req_index
,
None
)
if
generator
is
not
None
:
if
generator
is
not
None
:
self
.
generators
[
empty_index
]
=
generator
self
.
generators
[
empty_index
]
=
generator
...
@@ -357,6 +364,8 @@ class InputBatch:
...
@@ -357,6 +364,8 @@ class InputBatch:
self
.
request_lora_mapping
[
empty_index
]
=
self
.
request_lora_mapping
[
self
.
request_lora_mapping
[
empty_index
]
=
self
.
request_lora_mapping
[
last_req_index
]
last_req_index
]
self
.
logit_bias
[
empty_index
]
=
self
.
logit_bias
[
last_req_index
]
# Decrement last_req_index since it is now empty.
# Decrement last_req_index since it is now empty.
last_req_index
-=
1
last_req_index
-=
1
...
@@ -378,13 +387,16 @@ class InputBatch:
...
@@ -378,13 +387,16 @@ class InputBatch:
# penalties to be applied during sampling.
# penalties to be applied during sampling.
self
.
frequency_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
frequency_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
frequency_penalties_cpu_tensor
[:
self
.
num_reqs
],
self
.
frequency_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
non_blocking
=
True
,
)
self
.
presence_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
presence_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
presence_penalties_cpu_tensor
[:
self
.
num_reqs
],
self
.
presence_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
non_blocking
=
True
,
)
self
.
repetition_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
repetition_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
repetition_penalties_cpu_tensor
[:
self
.
num_reqs
],
self
.
repetition_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
non_blocking
=
True
,
)
# The prompt tokens are used only for applying penalties during
# The prompt tokens are used only for applying penalties during
# the sampling process. Hence copy these tensors only when
# the sampling process. Hence copy these tensors only when
# there are requests which need penalties to be applied.
# there are requests which need penalties to be applied.
...
@@ -421,6 +433,7 @@ class InputBatch:
...
@@ -421,6 +433,7 @@ class InputBatch:
min_tokens
=
self
.
min_tokens
[:
self
.
num_reqs
],
min_tokens
=
self
.
min_tokens
[:
self
.
num_reqs
],
stop_token_ids
=
self
.
stop_token_ids
[:
self
.
num_reqs
],
stop_token_ids
=
self
.
stop_token_ids
[:
self
.
num_reqs
],
no_penalties
=
self
.
no_penalties
,
no_penalties
=
self
.
no_penalties
,
logit_bias
=
self
.
logit_bias
[:
self
.
num_reqs
],
)
)
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
...
@@ -429,10 +442,11 @@ class InputBatch:
...
@@ -429,10 +442,11 @@ class InputBatch:
(
self
.
num_reqs
,
max_prompt_len
),
(
self
.
num_reqs
,
max_prompt_len
),
device
=
"cpu"
,
device
=
"cpu"
,
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
pin_memory
=
self
.
pin_memory
)
pin_memory
=
self
.
pin_memory
,
)
prompt_token_ids
=
prompt_token_ids_cpu_tensor
.
numpy
()
prompt_token_ids
=
prompt_token_ids_cpu_tensor
.
numpy
()
prompt_token_ids
[:]
=
(
prompt_token_ids
[:]
=
self
.
token_ids_cpu
[:
self
.
self
.
token_ids_cpu
[:
self
.
num_reqs
,
:
max_prompt_len
]
)
num_reqs
,
:
max_prompt_len
]
# Use the value of vocab_size as a pad since we don't have a
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
# token_id of this value.
for
i
in
range
(
self
.
num_reqs
):
for
i
in
range
(
self
.
num_reqs
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment