Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9f69d824
Unverified
Commit
9f69d824
authored
Jul 29, 2024
by
Nick Hill
Committed by
GitHub
Jul 29, 2024
Browse files
[Frontend] New `allowed_token_ids` decoding request parameter (#6753)
parent
9a7e2d05
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
114 additions
and
46 deletions
+114
-46
tests/entrypoints/openai/test_completion.py
tests/entrypoints/openai/test_completion.py
+22
-0
vllm/entrypoints/openai/logits_processors.py
vllm/entrypoints/openai/logits_processors.py
+74
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+16
-44
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+1
-1
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+1
-1
No files found.
tests/entrypoints/openai/test_completion.py
View file @
9f69d824
...
...
@@ -541,6 +541,28 @@ async def test_logits_bias(client: openai.AsyncOpenAI):
assert
first_response
!=
completion
.
choices
[
0
].
text
@pytest.mark.asyncio
async def test_allowed_token_ids(client: openai.AsyncOpenAI):
    """Check that `allowed_token_ids` confines generation to the given ids.

    A single token is requested greedily (temperature 0, fixed seed) and the
    id of the returned token must be one of the ids passed in the request.
    """
    prompt = "Hello, my name is"
    max_tokens = 1
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

    # Test exclusive selection
    allowed_ids = [21555, 21557, 21558]
    # `allowed_token_ids` is not a parameter of the OpenAI client's `create`,
    # so it is sent through `extra_body`.
    extra_request_fields = {"allowed_token_ids": allowed_ids}
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        seed=42,
        extra_body=extra_request_fields,
        logprobs=1,
    )

    # The logprobs payload carries the generated tokens as strings.
    generated_tokens = completion.choices[0].logprobs.tokens
    assert len(generated_tokens) == 1

    generated_ids = tokenizer.convert_tokens_to_ids(generated_tokens)
    assert generated_ids[0] in allowed_ids
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
...
...
vllm/entrypoints/openai/logits_processors.py
0 → 100644
View file @
9f69d824
from
functools
import
lru_cache
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Union
import
torch
from
transformers
import
PreTrainedTokenizer
from
vllm.sampling_params
import
LogitsProcessor
class AllowedTokenIdsLogitsProcessor:
    """Logits processor that restricts generation to a fixed set of token ids.

    The boolean mask is built lazily on the first call, sized from the last
    dimension of the logits it receives and placed on the same device. It is
    then cached on the instance and reused for every later call.
    """

    def __init__(self, allowed_ids: Iterable[int]):
        # Materialised as a list so it can be used as a tensor index; released
        # (set to None) once the mask has been built.
        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
        # True marks positions to suppress; None until the first call.
        self.mask: Optional[torch.Tensor] = None

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        """Set every logit outside the allowed set to -inf, in place.

        Args:
            token_ids: Accepted for the logits-processor call signature; not
                used by this processor.
            logits: Tensor of logits; it is modified in place.

        Returns:
            The same `logits` tensor that was passed in.
        """
        if self.mask is None:
            vocab_dim = logits.shape[-1]
            # Start with everything disallowed, then clear the allowed slots.
            self.mask = torch.ones((vocab_dim, ),
                                   dtype=torch.bool,
                                   device=logits.device)
            self.mask[self.allowed_ids] = False
            # The id list is no longer needed once the mask exists.
            self.allowed_ids = None

        logits.masked_fill_(self.mask, float("-inf"))
        return logits
@lru_cache(maxsize=32)
def _get_allowed_token_ids_logits_processor(
    allowed_token_ids: FrozenSet[int],
    vocab_size: int,
) -> LogitsProcessor:
    """Validate the allowed ids and return a processor enforcing them.

    Results are memoised (up to 32 entries) on the pair of arguments, so
    repeated calls with the same id set and vocab size return the same
    processor instance.

    Args:
        allowed_token_ids: Token ids generation is restricted to; a frozenset
            so the argument is hashable for the cache.
        vocab_size: Exclusive upper bound for a valid token id.

    Raises:
        ValueError: If the set is empty, or if any id is negative or not
            below `vocab_size`.
    """
    if not allowed_token_ids:
        raise ValueError("Empty allowed_token_ids provided")

    has_out_of_vocab_id = any(tid < 0 or tid >= vocab_size
                              for tid in allowed_token_ids)
    if has_out_of_vocab_id:
        raise ValueError("allowed_token_ids contains "
                         "out-of-vocab token id")

    return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
def get_logits_processors(
        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
        allowed_token_ids: Optional[List[int]],
        tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
    """Build the logits processors implied by the request parameters.

    Args:
        logit_bias: Optional mapping from token id (int, or a string holding
            an int) to an additive bias. An empty or None mapping adds no
            processor.
        allowed_token_ids: Optional list of the only token ids that may be
            generated. None adds no processor.
        tokenizer: Only its `vocab_size` is read, to validate
            `allowed_token_ids`.

    Returns:
        A list with the bias processor first (if any), followed by the
        allowed-token-ids processor (if any). The list may be empty.

    Raises:
        ValueError: If a `logit_bias` key cannot be converted to an integer,
            or if `allowed_token_ids` fails validation.
    """
    processors = []

    if logit_bias:
        bias_by_token_id: Dict[int, float] = {}
        try:
            for raw_token_id, raw_bias in logit_bias.items():
                # Convert token_id to integer
                # Clamp the bias between -100 and 100 per OpenAI API spec
                bias_by_token_id[int(raw_token_id)] = min(
                    100.0, max(-100.0, raw_bias))
        except ValueError as exc:
            raise ValueError(
                "Found token_id in logit_bias that is not "
                "an integer or string representing an integer") from exc

        def logit_bias_logits_processor(token_ids: List[int],
                                        logits: torch.Tensor) -> torch.Tensor:
            # Adds each clamped bias to its logit in place; `token_ids` is
            # unused.
            for biased_id, delta in bias_by_token_id.items():
                logits[biased_id] += delta
            return logits

        processors.append(logit_bias_logits_processor)

    if allowed_token_ids is not None:
        # A frozenset makes the ids hashable for the cached factory.
        allowed_set = frozenset(allowed_token_ids)
        allowed_processor = _get_allowed_token_ids_logits_processor(
            allowed_set, tokenizer.vocab_size)
        processors.append(allowed_processor)

    return processors
vllm/entrypoints/openai/protocol.py
View file @
9f69d824
...
...
@@ -5,9 +5,11 @@ from typing import Any, Dict, List, Literal, Optional, Union
import
torch
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
transformers
import
PreTrainedTokenizer
from
typing_extensions
import
Annotated
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
random_uuid
...
...
@@ -213,30 +215,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
def
to_sampling_params
(
self
)
->
SamplingParams
:
def
to_sampling_params
(
self
,
tokenizer
:
PreTrainedTokenizer
)
->
SamplingParams
:
# We now allow logprobs being true without top_logprobs.
logits_processors
=
None
if
self
.
logit_bias
:
logit_bias
:
Dict
[
int
,
float
]
=
{}
try
:
for
token_id
,
bias
in
self
.
logit_bias
.
items
():
# Convert token_id to integer before we add to LLMEngine
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias
[
int
(
token_id
)]
=
min
(
100
,
max
(
-
100
,
bias
))
except
ValueError
as
exc
:
raise
ValueError
(
f
"Found token_id `
{
token_id
}
` in logit_bias "
f
"but token_id must be an integer or string "
f
"representing an integer"
)
from
exc
def
logit_bias_logits_processor
(
token_ids
:
List
[
int
],
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
for
token_id
,
bias
in
logit_bias
.
items
():
logits
[
token_id
]
+=
bias
return
logits
logits_processors
=
[
logit_bias_logits_processor
]
logits_processors
=
get_logits_processors
(
logit_bias
=
self
.
logit_bias
,
allowed_token_ids
=
None
,
tokenizer
=
tokenizer
,
)
return
SamplingParams
(
n
=
self
.
n
,
...
...
@@ -358,6 +345,7 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
allowed_token_ids
:
Optional
[
List
[
int
]]
=
None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
...
...
@@ -407,30 +395,14 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
def
to_sampling_params
(
self
):
def
to_sampling_params
(
self
,
tokenizer
:
PreTrainedTokenizer
):
echo_without_generation
=
self
.
echo
and
self
.
max_tokens
==
0
logits_processors
=
None
if
self
.
logit_bias
:
logit_bias
:
Dict
[
int
,
float
]
=
{}
try
:
for
token_id
,
bias
in
self
.
logit_bias
.
items
():
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias
[
int
(
token_id
)]
=
min
(
100
,
max
(
-
100
,
bias
))
except
ValueError
as
exc
:
raise
ValueError
(
f
"Found token_id `
{
token_id
}
` in logit_bias "
f
"but token_id must be an integer or string "
f
"representing an integer"
)
from
exc
def
logit_bias_logits_processor
(
token_ids
:
List
[
int
],
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
for
token_id
,
bias
in
logit_bias
.
items
():
logits
[
token_id
]
+=
bias
return
logits
logits_processors
=
[
logit_bias_logits_processor
]
logits_processors
=
get_logits_processors
(
logit_bias
=
self
.
logit_bias
,
allowed_token_ids
=
self
.
allowed_token_ids
,
tokenizer
=
tokenizer
,
)
return
SamplingParams
(
n
=
self
.
n
,
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
9f69d824
...
...
@@ -134,7 +134,7 @@ class OpenAIServingChat(OpenAIServing):
request_id
=
f
"chat-
{
random_uuid
()
}
"
try
:
sampling_params
=
request
.
to_sampling_params
()
sampling_params
=
request
.
to_sampling_params
(
tokenizer
)
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
9f69d824
...
...
@@ -95,7 +95,7 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
sampling_params
=
request
.
to_sampling_params
()
sampling_params
=
request
.
to_sampling_params
(
tokenizer
)
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment