Unverified commit e0ade06d, authored Feb 26, 2024 by Dylan Hawk, committed by GitHub on Feb 27, 2024

Support logit bias for OpenAI API (#3027)

Parent: 4bd18ec0
Showing 4 changed files with 83 additions and 12 deletions:

- tests/entrypoints/test_openai_server.py (+48, −0)
- vllm/entrypoints/openai/protocol.py (+33, −0)
- vllm/entrypoints/openai/serving_chat.py (+1, −7)
- vllm/entrypoints/openai/serving_completion.py (+1, −5)
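For context on what the new parameter does (not part of the diff): the OpenAI API's `logit_bias` maps token IDs, given as strings, to a value clamped to [-100, 100] that is added to the corresponding logits before sampling, so +100 effectively forces a token under greedy decoding and -100 effectively bans it. A self-contained toy illustration (vocabulary size, logits, and token IDs are made up):

```python
# Illustrative only: how a logit_bias entry shifts a toy logits vector.
# The vocabulary size, logits, and token IDs here are made up.
import torch

logits = torch.tensor([2.0, 1.5, 0.5, 3.0])   # pretend vocabulary of 4 tokens
logit_bias = {"1": 100, "3": -100}             # token id (as str) -> bias

for token_id, bias in logit_bias.items():
    bias = min(100, max(-100, bias))           # clamp per OpenAI API spec
    logits[int(token_id)] += bias

print(torch.argmax(logits).item())  # 1: the +100 bias dominates greedy selection
print(logits[3].item())             # -97.0: token 3 is effectively banned
```

This is the behavior the new tests below check end-to-end through the server.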
tests/entrypoints/test_openai_server.py

```python
@@ -9,6 +9,8 @@ import ray  # using Ray for overall ease of process management, parallel request
import openai  # use the official client for correctness check
from huggingface_hub import snapshot_download  # downloading lora to test lora requests
from vllm.transformers_utils.tokenizer import get_tokenizer

MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 60 seconds
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here

@@ -310,5 +312,51 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
    assert texts[0] == texts[1]


async def test_logits_bias(server, client: openai.AsyncOpenAI):
    prompt = "Hello, my name is"
    max_tokens = 5
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

    # Test exclusive selection
    token_id = 1000
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token_id): 100},
    )
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
    response_tokens = tokenizer(completion.choices[0].text,
                                add_special_tokens=False)["input_ids"]
    expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
                                add_special_tokens=False)["input_ids"]
    assert all([
        response == expected
        for response, expected in zip(response_tokens, expected_tokens)
    ])

    # Test ban
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
    )
    response_tokens = tokenizer(completion.choices[0].text,
                                add_special_tokens=False)["input_ids"]
    first_response = completion.choices[0].text
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token): -100
                    for token in response_tokens},
    )
    assert first_response != completion.choices[0].text


if __name__ == "__main__":
    pytest.main([__file__])
```
vllm/entrypoints/openai/protocol.py

```python
@@ -8,6 +8,8 @@ from pydantic import BaseModel, Field
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
import torch


class ErrorResponse(BaseModel):
    object: str = "error"

@@ -88,6 +90,21 @@ class ChatCompletionRequest(BaseModel):
    def to_sampling_params(self) -> SamplingParams:
        if self.logprobs and not self.top_logprobs:
            raise ValueError("Top logprobs must be set when logprobs is.")

        logits_processors = None
        if self.logit_bias:

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                for token_id, bias in self.logit_bias.items():
                    # Clamp the bias between -100 and 100 per OpenAI API spec
                    bias = min(100, max(-100, bias))
                    logits[int(token_id)] += bias
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,

@@ -111,6 +128,7 @@ class ChatCompletionRequest(BaseModel):
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
        )

@@ -149,6 +167,20 @@ class CompletionRequest(BaseModel):
    def to_sampling_params(self):
        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = None
        if self.logit_bias:

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                for token_id, bias in self.logit_bias.items():
                    # Clamp the bias between -100 and 100 per OpenAI API spec
                    bias = min(100, max(-100, bias))
                    logits[int(token_id)] += bias
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,

@@ -172,6 +204,7 @@ class CompletionRequest(BaseModel):
            spaces_between_special_tokens=(self.spaces_between_special_tokens),
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
        )
```
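The `logit_bias_logits_processor` built in `to_sampling_params` is an ordinary vLLM logits processor: a callable that takes the generated token IDs and the logits tensor and returns adjusted logits. As a sketch of the same mechanism outside the HTTP layer (not code from this commit; the model name and values are assumptions), such a processor could be passed to `SamplingParams` through the offline `LLM` API:

```python
# Sketch, not part of the commit: wiring a logit_bias-style processor
# into vLLM's SamplingParams directly via the offline LLM API.
from typing import List

import torch
from vllm import LLM, SamplingParams

logit_bias = {"1000": 100.0}  # token id (as str) -> bias, OpenAI-style


def logit_bias_logits_processor(token_ids: List[int],
                                logits: torch.Tensor) -> torch.Tensor:
    # Mirrors the processor built in protocol.py: clamp the bias and
    # add it to the chosen token's logit at every decoding step.
    for token_id, bias in logit_bias.items():
        logits[int(token_id)] += min(100.0, max(-100.0, bias))
    return logits


llm = LLM(model="facebook/opt-125m")  # assumption: any small local model
params = SamplingParams(temperature=0.0,
                        max_tokens=5,
                        logits_processors=[logit_bias_logits_processor])
out = llm.generate(["Hello, my name is"], params)
print(out[0].outputs[0].text)  # greedy decoding keeps picking token 1000
```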
vllm/entrypoints/openai/serving_chat.py

```diff
@@ -39,19 +39,13 @@ class OpenAIServingChat(OpenAIServing):
         See https://platform.openai.com/docs/api-reference/chat/create
         for the API specification. This API mimics the OpenAI ChatCompletion API.
 
-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - function_call (Users should implement this by themselves)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
 
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            # TODO: support logit_bias in vLLM engine.
-            return self.create_error_response(
-                "logit_bias is not currently supported")
-
         try:
             prompt = self.tokenizer.apply_chat_template(
                 conversation=request.messages,
```
vllm/entrypoints/openai/serving_completion.py

```diff
@@ -264,10 +264,9 @@ class OpenAIServingCompletion(OpenAIServing):
         See https://platform.openai.com/docs/api-reference/completions/create
         for the API specification. This API mimics the OpenAI Completion API.
 
-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - suffix (the language models we currently support do not support
             suffix)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:

@@ -277,9 +276,6 @@ class OpenAIServingCompletion(OpenAIServing):
         if request.suffix is not None:
             return self.create_error_response(
                 "suffix is not currently supported")
 
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            return self.create_error_response(
-                "logit_bias is not currently supported")
         model_name = request.model
         request_id = f"cmpl-{random_uuid()}"
```
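With the early return removed from both serving_chat.py and serving_completion.py, a request's `logit_bias` now reaches `to_sampling_params` instead of producing an error response. A minimal end-to-end sketch using the official OpenAI client against a locally running vLLM OpenAI-compatible server (the base URL, API key, and token ID are assumptions for illustration):

```python
# Sketch, not from the commit: calling a running vLLM OpenAI-compatible
# server with logit_bias. Base URL, api_key, and the token ID are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    messages=[{"role": "user", "content": "Say something."}],
    max_tokens=5,
    temperature=0.0,
    logit_bias={"1000": -100},  # effectively ban token id 1000
)
print(chat.choices[0].message.content)
```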