Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
555bdcc5
Unverified
Commit
555bdcc5
authored
Nov 03, 2023
by
Noam Gat
Committed by
GitHub
Nov 03, 2023
Browse files
Added logits processor API to sampling params (#1469)
parent
54ca1ba7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
70 additions
and
2 deletions
+70
-2
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+34
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+24
-0
vllm/sampling_params.py
vllm/sampling_params.py
+12
-2
No files found.
tests/samplers/test_sampler.py
View file @
555bdcc5
...
...
@@ -183,3 +183,37 @@ def test_sampler_mixed(seed: int):
continue
for
nth_output
in
sequence_output
.
samples
:
assert
nth_output
.
output_token
in
expected_tokens
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
def
test_sampler_logits_processors
(
seed
:
int
):
set_random_seed
(
seed
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
_
,
sampler
,
worker
=
_prepare_test
(
batch_size
)
# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def
pick_ith
(
token_ids
,
logits
):
logits
[
len
(
token_ids
)]
=
float
(
"inf"
)
return
logits
seq_group_metadata_list
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
sampling_params
=
SamplingParams
(
temperature
=
0
,
logits_processors
=
[
pick_ith
]),
block_tables
=
{
0
:
[
1
]},
))
_
,
_
,
input_metadata
=
worker
.
_prepare_inputs
(
seq_group_metadata_list
)
sampler_output
=
sampler
(
embedding
=
None
,
hidden_states
=
input_tensor
,
input_metadata
=
input_metadata
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
idx
,
nth_output
in
enumerate
(
sequence_output
.
samples
):
assert
nth_output
.
output_token
==
idx
vllm/model_executor/layers/sampler.py
View file @
555bdcc5
...
...
@@ -47,6 +47,8 @@ class Sampler(nn.Module):
logits
=
_get_logits
(
hidden_states
,
embedding
,
embedding_bias
,
self
.
vocab_size
)
# Apply logits processors (if any).
logits
=
_apply_logits_processors
(
logits
,
input_metadata
)
# Apply presence and frequency penalties.
output_tokens
=
_get_output_tokens
(
input_metadata
)
assert
len
(
output_tokens
)
==
logits
.
shape
[
0
]
...
...
@@ -155,6 +157,28 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
return
output_tokens
def
_apply_logits_processors
(
logits
:
torch
.
Tensor
,
input_metadata
:
InputMetadata
)
->
torch
.
Tensor
:
logits_row_idx
=
0
found_logits_processors
=
False
for
seq_ids
,
sampling_params
in
input_metadata
.
seq_groups
:
logits_processors
=
sampling_params
.
logits_processors
if
logits_processors
:
found_logits_processors
=
True
for
seq_id
in
seq_ids
:
logits_row
=
logits
[
logits_row_idx
]
token_ids
=
input_metadata
.
seq_data
[
seq_id
].
output_token_ids
for
logits_processor
in
logits_processors
:
logits_row
=
logits_processor
(
token_ids
,
logits_row
)
logits
[
logits_row_idx
]
=
logits_row
logits_row_idx
+=
1
else
:
logits_row_idx
+=
len
(
seq_ids
)
if
found_logits_processors
:
assert
logits_row_idx
==
logits
.
shape
[
0
]
return
logits
def
_apply_penalties
(
logits
:
torch
.
Tensor
,
output_tokens
:
List
[
List
[
int
]],
...
...
vllm/sampling_params.py
View file @
555bdcc5
"""Sampling parameters for text generation."""
from
enum
import
IntEnum
from
functools
import
cached_property
from
typing
import
List
,
Optional
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Union
import
torch
_SAMPLING_EPS
=
1e-5
...
...
@@ -12,6 +13,12 @@ class SamplingType(IntEnum):
BEAM
=
2
LogitsProcessor
=
Callable
[[
List
[
int
],
torch
.
Tensor
],
torch
.
Tensor
]
"""LogitsProcessor is a function that takes a list of previously generated
tokens and a tensor of the logits for the next token, and returns a modified
tensor of logits to sample from."""
class
SamplingParams
:
"""Sampling parameters for text generation.
...
...
@@ -73,6 +80,8 @@ class SamplingParams:
skip_special_tokens: Whether to skip special tokens in the output.
spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
previously generated tokens.
"""
def
__init__
(
...
...
@@ -96,6 +105,7 @@ class SamplingParams:
prompt_logprobs
:
Optional
[
int
]
=
None
,
skip_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
)
->
None
:
self
.
n
=
n
self
.
best_of
=
best_of
if
best_of
is
not
None
else
n
...
...
@@ -124,7 +134,7 @@ class SamplingParams:
self
.
prompt_logprobs
=
prompt_logprobs
self
.
skip_special_tokens
=
skip_special_tokens
self
.
spaces_between_special_tokens
=
spaces_between_special_tokens
self
.
logits_processors
=
logits_processors
self
.
_verify_args
()
if
self
.
use_beam_search
:
self
.
_verify_beam_search
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment