Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
555bdcc5
"vscode:/vscode.git/clone" did not exist on "3b8a14038cf615f5e96a5c85a09c2dfdb6351746"
Unverified
Commit
555bdcc5
authored
Nov 03, 2023
by
Noam Gat
Committed by
GitHub
Nov 03, 2023
Browse files
Added logits processor API to sampling params (#1469)
parent
54ca1ba7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
70 additions
and
2 deletions
+70
-2
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+34
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+24
-0
vllm/sampling_params.py
vllm/sampling_params.py
+12
-2
No files found.
tests/samplers/test_sampler.py
View file @
555bdcc5
...
...
@@ -183,3 +183,37 @@ def test_sampler_mixed(seed: int):
continue
for
nth_output
in
sequence_output
.
samples
:
assert
nth_output
.
output_token
in
expected_tokens
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
    """Check that processors given via ``SamplingParams`` steer the sampler.

    Every request in a randomly sized batch carries one logits processor
    that sets a single logit to +inf; the sampled tokens are then compared
    against the position of each sample in its group's output.
    """
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, worker = _prepare_test(batch_size)

    # The processor raises the logit at index ``len(token_ids)`` to +inf, so
    # the chosen token is expected to equal the number of token ids it was
    # handed.
    # NOTE(review): presumably ``token_ids`` is empty on a prompt step, which
    # would make the expected token 0 -- confirm against the sampler.
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    # One prompt-stage sequence group per batch entry, all sharing the same
    # three-token prompt and the same processor.
    seq_group_metadata_list = [
        SequenceGroupMetadata(
            request_id=f"test_{request_idx}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=SamplingParams(temperature=0,
                                           logits_processors=[pick_ith]),
            block_tables={0: [1]},
        ) for request_idx in range(batch_size)
    ]

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)

    # Each sample's token must match its index within the group's samples.
    for sequence_output in sampler_output:
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx
vllm/model_executor/layers/sampler.py
View file @
555bdcc5
...
...
@@ -47,6 +47,8 @@ class Sampler(nn.Module):
logits
=
_get_logits
(
hidden_states
,
embedding
,
embedding_bias
,
self
.
vocab_size
)
# Apply logits processors (if any).
logits
=
_apply_logits_processors
(
logits
,
input_metadata
)
# Apply presence and frequency penalties.
output_tokens
=
_get_output_tokens
(
input_metadata
)
assert
len
(
output_tokens
)
==
logits
.
shape
[
0
]
...
...
@@ -155,6 +157,28 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
return
output_tokens
def _apply_logits_processors(logits: torch.Tensor,
                             input_metadata: InputMetadata) -> torch.Tensor:
    """Run each sequence group's logits processors over its logits rows.

    Walks ``input_metadata.seq_groups`` in order, consuming one row of
    ``logits`` per sequence id. For groups whose sampling params define
    processors, every processor is called in list order as
    ``processor(output_token_ids, row)`` and the final result is written
    back into ``logits`` at the same row. Groups without processors only
    advance the row cursor.

    Args:
        logits: Tensor indexed by row along its first dimension; modified
            in place.
        input_metadata: Provides ``seq_groups`` (pairs of sequence ids and
            sampling params) and ``seq_data`` (per-sequence data exposing
            ``output_token_ids``).

    Returns:
        The same ``logits`` tensor that was passed in.
    """
    row_cursor = 0
    any_processors = False

    for seq_ids, sampling_params in input_metadata.seq_groups:
        processors = sampling_params.logits_processors

        if not processors:
            # Nothing to apply: skip over this group's rows in one step.
            row_cursor += len(seq_ids)
            continue

        any_processors = True
        for seq_id in seq_ids:
            current_row = logits[row_cursor]
            generated_ids = input_metadata.seq_data[seq_id].output_token_ids
            # Chain the processors: each one receives the previous result.
            for processor in processors:
                current_row = processor(generated_ids, current_row)
            logits[row_cursor] = current_row
            row_cursor += 1

    # Sanity check (only when something was applied): the walk must have
    # covered exactly one row per sequence.
    if any_processors:
        assert row_cursor == logits.shape[0]
    return logits
def
_apply_penalties
(
logits
:
torch
.
Tensor
,
output_tokens
:
List
[
List
[
int
]],
...
...
vllm/sampling_params.py
View file @
555bdcc5
"""Sampling parameters for text generation."""
from
enum
import
IntEnum
from
functools
import
cached_property
from
typing
import
List
,
Optional
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Union
import
torch
_SAMPLING_EPS
=
1e-5
...
...
@@ -12,6 +13,12 @@ class SamplingType(IntEnum):
BEAM
=
2
# Signature of a user-supplied logits hook: (token ids, logits) -> logits.
# The sampler's _apply_logits_processors calls each hook with a sequence's
# output_token_ids and a single row of the logits tensor, chaining the hooks
# in list order and storing the last return value back into that row.
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
"""LogitsProcessor is a function that takes a list of previously generated
tokens and a tensor of the logits for the next token, and returns a modified
tensor of logits to sample from."""
class
SamplingParams
:
"""Sampling parameters for text generation.
...
...
@@ -73,6 +80,8 @@ class SamplingParams:
skip_special_tokens: Whether to skip special tokens in the output.
spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
previously generated tokens.
"""
def
__init__
(
...
...
@@ -96,6 +105,7 @@ class SamplingParams:
prompt_logprobs
:
Optional
[
int
]
=
None
,
skip_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
)
->
None
:
self
.
n
=
n
self
.
best_of
=
best_of
if
best_of
is
not
None
else
n
...
...
@@ -124,7 +134,7 @@ class SamplingParams:
self
.
prompt_logprobs
=
prompt_logprobs
self
.
skip_special_tokens
=
skip_special_tokens
self
.
spaces_between_special_tokens
=
spaces_between_special_tokens
self
.
logits_processors
=
logits_processors
self
.
_verify_args
()
if
self
.
use_beam_search
:
self
.
_verify_beam_search
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment