norm / vllm · Commit e0ade06d (unverified)

Support logit bias for OpenAI API (#3027)

Authored by Dylan Hawk on Feb 26, 2024; committed via GitHub on Feb 27, 2024
Parent: 4bd18ec0
Showing 4 changed files with 83 additions and 12 deletions:
  tests/entrypoints/test_openai_server.py        +48  -0
  vllm/entrypoints/openai/protocol.py            +33  -0
  vllm/entrypoints/openai/serving_chat.py         +1  -7
  vllm/entrypoints/openai/serving_completion.py   +1  -5
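Taken together, the change removes the "logit_bias is not currently supported" guards from both OpenAI-compatible endpoints and translates request.logit_bias into a logits processor on SamplingParams. As a minimal client-side sketch of what this enables (server address, API key, model name, and token id are illustrative, assuming a vLLM OpenAI-compatible server is already running; this is not part of the commit itself):

import asyncio

import openai  # official OpenAI client, as used in the commit's tests

# Assumed local vLLM OpenAI-compatible server; address and key are illustrative.
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


async def main() -> None:
    completion = await client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0,
        # OpenAI-style logit_bias: token id as a string key, bias in [-100, 100].
        logit_bias={"1000": 100},
    )
    print(completion.choices[0].text)


asyncio.run(main())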
tests/entrypoints/test_openai_server.py  (view file @ e0ade06d)
@@ -9,6 +9,8 @@ import ray  # using Ray for overall ease of process management, parallel request
 import openai  # use the official client for correctness check
 from huggingface_hub import snapshot_download  # downloading lora to test lora requests
+from vllm.transformers_utils.tokenizer import get_tokenizer

 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 60 seconds
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
@@ -310,5 +312,51 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
     assert texts[0] == texts[1]


+async def test_logits_bias(server, client: openai.AsyncOpenAI):
+    prompt = "Hello, my name is"
+    max_tokens = 5
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+
+    # Test exclusive selection
+    token_id = 1000
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        logit_bias={str(token_id): 100},
+    )
+    assert completion.choices[0].text is not None and len(
+        completion.choices[0].text) >= 5
+    response_tokens = tokenizer(completion.choices[0].text,
+                                add_special_tokens=False)["input_ids"]
+    expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
+                                add_special_tokens=False)["input_ids"]
+    assert all([
+        response == expected
+        for response, expected in zip(response_tokens, expected_tokens)
+    ])
+
+    # Test ban
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+    )
+    response_tokens = tokenizer(completion.choices[0].text,
+                                add_special_tokens=False)["input_ids"]
+    first_response = completion.choices[0].text
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        logit_bias={str(token): -100
+                    for token in response_tokens},
+    )
+    assert first_response != completion.choices[0].text
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
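The test above biases an arbitrary token id (1000). When targeting a specific string instead, the same get_tokenizer helper imported at the top of the test file can be used to look up token ids before building the logit_bias mapping. A minimal sketch, assuming local access to the model's tokenizer; the model name and target word are illustrative, not part of this commit:

from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # illustrative; any served model works

tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

# Token ids making up the string we want to discourage.
banned_ids = tokenizer("Paris", add_special_tokens=False)["input_ids"]

# OpenAI-style logit_bias: keys are token ids as strings, values in [-100, 100].
logit_bias = {str(token_id): -100 for token_id in banned_ids}
print(logit_bias)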
vllm/entrypoints/openai/protocol.py  (view file @ e0ade06d)
@@ -8,6 +8,8 @@ from pydantic import BaseModel, Field
 from vllm.utils import random_uuid
 from vllm.sampling_params import SamplingParams
+import torch


 class ErrorResponse(BaseModel):
     object: str = "error"
@@ -88,6 +90,21 @@ class ChatCompletionRequest(BaseModel):
     def to_sampling_params(self) -> SamplingParams:
         if self.logprobs and not self.top_logprobs:
             raise ValueError("Top logprobs must be set when logprobs is.")
+
+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
         return SamplingParams(
             n=self.n,
             presence_penalty=self.presence_penalty,

@@ -111,6 +128,7 @@ class ChatCompletionRequest(BaseModel):
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
         )
@@ -149,6 +167,20 @@ class CompletionRequest(BaseModel):
     def to_sampling_params(self):
         echo_without_generation = self.echo and self.max_tokens == 0
+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
         return SamplingParams(
             n=self.n,
             best_of=self.best_of,

@@ -172,6 +204,7 @@ class CompletionRequest(BaseModel):
             spaces_between_special_tokens=(self.spaces_between_special_tokens),
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
         )
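The mechanism here is that each request's logit_bias dictionary is turned into a logits processor, which vLLM then applies to the raw next-token logits via SamplingParams(logits_processors=...). The same idea can be used directly with the offline LLM entry point; the following is a sketch with illustrative model, token id, and bias values, not code from this commit:

from typing import List

import torch
from vllm import LLM, SamplingParams

# Same shape of processor the commit builds from request.logit_bias: it receives
# the token ids generated so far and the logits for the next token.
logit_bias = {1000: 100.0}  # illustrative token id and bias


def logit_bias_logits_processor(token_ids: List[int],
                                logits: torch.Tensor) -> torch.Tensor:
    for token_id, bias in logit_bias.items():
        # Clamp to the [-100, 100] range used by the OpenAI API spec.
        logits[token_id] += min(100.0, max(-100.0, bias))
    return logits


llm = LLM(model="HuggingFaceH4/zephyr-7b-beta")  # illustrative model
params = SamplingParams(max_tokens=5,
                        temperature=0.0,
                        logits_processors=[logit_bias_logits_processor])
print(llm.generate(["Hello, my name is"], params)[0].outputs[0].text)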
vllm/entrypoints/openai/serving_chat.py  (view file @ e0ade06d)

@@ -39,19 +39,13 @@ class OpenAIServingChat(OpenAIServing):
         See https://platform.openai.com/docs/api-reference/chat/create
         for the API specification. This API mimics the OpenAI ChatCompletion API.

-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - function_call (Users should implement this by themselves)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            # TODO: support logit_bias in vLLM engine.
-            return self.create_error_response(
-                "logit_bias is not currently supported")
-
         try:
             prompt = self.tokenizer.apply_chat_template(
                 conversation=request.messages,
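With this guard removed, logit_bias on the chat endpoint is forwarded into ChatCompletionRequest.logit_bias and applied through the logits processor shown in protocol.py. A short client sketch against an assumed local vLLM server (address, key, model, and token id are illustrative):

import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


async def main() -> None:
    chat = await client.chat.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        messages=[{"role": "user", "content": "Say hello in one word."}],
        max_tokens=16,
        # Token id as a string key, bias clamped server-side to [-100, 100].
        logit_bias={"1000": 50},
    )
    print(chat.choices[0].message.content)


asyncio.run(main())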
vllm/entrypoints/openai/serving_completion.py  (view file @ e0ade06d)

@@ -264,10 +264,9 @@ class OpenAIServingCompletion(OpenAIServing):
         See https://platform.openai.com/docs/api-reference/completions/create
         for the API specification. This API mimics the OpenAI Completion API.

-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - suffix (the language models we currently support do not support
               suffix)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:

@@ -277,9 +276,6 @@ class OpenAIServingCompletion(OpenAIServing):
         if request.suffix is not None:
             return self.create_error_response(
                 "suffix is not currently supported")

-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            return self.create_error_response(
-                "logit_bias is not currently supported")
-
         model_name = request.model
         request_id = f"cmpl-{random_uuid()}"
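With the equivalent guard gone on the completions endpoint, the "ban" pattern exercised in the new test also works end to end: send negative biases for the token ids you want to avoid. A compact sketch, using the same assumed local server and illustrative model as the earlier examples:

import asyncio

import openai
from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # illustrative
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)


async def main() -> None:
    # First response, unconstrained.
    first = await client.completions.create(model=MODEL_NAME,
                                            prompt="Hello, my name is",
                                            max_tokens=5,
                                            temperature=0.0)
    banned = tokenizer(first.choices[0].text,
                       add_special_tokens=False)["input_ids"]

    # Second response with every token of the first strongly disfavored.
    second = await client.completions.create(
        model=MODEL_NAME,
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0,
        logit_bias={str(token_id): -100 for token_id in banned},
    )
    print(first.choices[0].text, "->", second.choices[0].text)


asyncio.run(main())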