Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
120157fd
Unverified
Commit
120157fd
authored
Mar 16, 2024
by
Simon Mo
Committed by
GitHub
Mar 16, 2024
Browse files
Support arbitrary json_object in OpenAI and Context Free Grammar (#3211)
parent
8e67598a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
176 additions
and
49 deletions
+176
-49
tests/entrypoints/test_openai_server.py
tests/entrypoints/test_openai_server.py
+50
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+9
-0
vllm/model_executor/guided_decoding.py
vllm/model_executor/guided_decoding.py
+41
-13
vllm/model_executor/guided_logits_processors.py
vllm/model_executor/guided_logits_processors.py
+76
-36
No files found.
tests/entrypoints/test_openai_server.py
View file @
120157fd
...
@@ -660,5 +660,55 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
...
@@ -660,5 +660,55 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
,
guided_json
=
TEST_SCHEMA
))
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
,
guided_json
=
TEST_SCHEMA
))
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
    """Check that ``response_format={"type": "json_object"}`` yields JSON.

    The chat model is asked a trivial arithmetic question and told the JSON
    shape to answer with; the reply must parse with ``json.loads`` and be
    equal to ``{"result": 2}``.
    """
    question = ('what is 1+1? please respond with a JSON object, '
                'the format is {"result": 2}')
    chat_messages = [{"role": "user", "content": question}]

    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=chat_messages,
        response_format={"type": "json_object"},
    )

    parsed = json.loads(response.choices[0].message.content)
    # The parsed payload doubles as the assertion message to ease debugging.
    assert parsed == {"result": 2}, parsed
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
    """Check that ``guided_grammar`` constrains a completion to a grammar.

    A small SQL-like Lark grammar is sent through ``extra_body``; the
    returned text is then parsed with that same grammar and compared with
    the single expected statement.
    """
    simple_sql_grammar = """
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
"""
    sql_prompt = ("Generate a sql state that select col_1 from "
                  "table_1 where it is equals to 1")

    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=sql_prompt,
        temperature=1.0,
        max_tokens=500,
        extra_body={"guided_grammar": simple_sql_grammar},
    )
    generated = completion.choices[0].text

    # Parse the output with the very grammar that guided it, to make sure it
    # forms a valid parse tree.
    from lark import Lark

    Lark(simple_sql_grammar).parse(generated)

    # The grammar defines no whitespace between tokens, so the spaces are
    # stripped from the reference statement before comparing.
    expected = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")
    assert generated.strip() == expected
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
pytest
.
main
([
__file__
])
vllm/entrypoints/openai/protocol.py
View file @
120157fd
...
@@ -55,6 +55,11 @@ class UsageInfo(BaseModel):
...
@@ -55,6 +55,11 @@ class UsageInfo(BaseModel):
completion_tokens
:
Optional
[
int
]
=
0
completion_tokens
:
Optional
[
int
]
=
0
class ResponseFormat(BaseModel):
    """Output format requested through the ``response_format`` request field."""

    # Must be "json_object" or "text". The allowed values live in the
    # annotation so that pydantic validates them; the previous form
    # (`type: str = Literal[...]`) declared an unconstrained `str` whose
    # *default value* was the `Literal` typing object itself.
    # The "text" default keeps payloads that omit `type` accepted, as before.
    type: Literal["text", "json_object"] = "text"
class
ChatCompletionRequest
(
BaseModel
):
class
ChatCompletionRequest
(
BaseModel
):
model
:
str
model
:
str
messages
:
List
[
Dict
[
str
,
str
]]
messages
:
List
[
Dict
[
str
,
str
]]
...
@@ -89,6 +94,8 @@ class ChatCompletionRequest(BaseModel):
...
@@ -89,6 +94,8 @@ class ChatCompletionRequest(BaseModel):
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
None
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
None
guided_regex
:
Optional
[
str
]
=
None
guided_regex
:
Optional
[
str
]
=
None
guided_choice
:
Optional
[
List
[
str
]]
=
None
guided_choice
:
Optional
[
List
[
str
]]
=
None
guided_grammar
:
Optional
[
str
]
=
None
response_format
:
Optional
[
ResponseFormat
]
=
None
def
to_sampling_params
(
self
)
->
SamplingParams
:
def
to_sampling_params
(
self
)
->
SamplingParams
:
if
self
.
logprobs
and
not
self
.
top_logprobs
:
if
self
.
logprobs
and
not
self
.
top_logprobs
:
...
@@ -183,6 +190,8 @@ class CompletionRequest(BaseModel):
...
@@ -183,6 +190,8 @@ class CompletionRequest(BaseModel):
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
None
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
None
guided_regex
:
Optional
[
str
]
=
None
guided_regex
:
Optional
[
str
]
=
None
guided_choice
:
Optional
[
List
[
str
]]
=
None
guided_choice
:
Optional
[
List
[
str
]]
=
None
guided_grammar
:
Optional
[
str
]
=
None
response_format
:
Optional
[
ResponseFormat
]
=
None
def
to_sampling_params
(
self
):
def
to_sampling_params
(
self
):
echo_without_generation
=
self
.
echo
and
self
.
max_tokens
==
0
echo_without_generation
=
self
.
echo
and
self
.
max_tokens
==
0
...
...
vllm/model_executor/guided_decoding.py
View file @
120157fd
...
@@ -6,19 +6,50 @@ from functools import lru_cache
...
@@ -6,19 +6,50 @@ from functools import lru_cache
from
json
import
dumps
as
json_dumps
from
json
import
dumps
as
json_dumps
from
re
import
escape
as
regex_escape
from
re
import
escape
as
regex_escape
from
typing
import
Union
,
Tuple
from
typing
import
Union
,
Tuple
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
CompletionRequest
,
from
vllm.entrypoints.openai.protocol
import
(
CompletionRequest
,
ChatCompletionRequest
)
ChatCompletionRequest
)
from
vllm.model_executor.guided_logits_processors
import
(
JSONLogitsProcessor
,
from
vllm.model_executor.guided_logits_processors
import
(
JSONLogitsProcessor
,
RegexLogitsProcessor
)
RegexLogitsProcessor
,
CFGLogitsProcessor
)
class
GuidedDecodingMode
(
Enum
):
class
GuidedDecodingMode
(
Enum
):
JSON
=
"json"
JSON
=
"json"
REGEX
=
"regex"
REGEX
=
"regex"
CHOICE
=
"choice"
CHOICE
=
"choice"
GRAMMAR
=
"grammar"
# Adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# The main difference is that the start rule was changed from `start: value`
# to `start: object | array`, so scalar values are rejected as the root of
# the JSON document. Allowing a scalar at the root seems to cause Llama to
# keep generating without stopping.
JSON_GRAMMAR
=
r
"""
?start: object | array
?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null
array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : UNESCAPED_STRING ":" value
%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
"""
global_thread_pool
=
None
# used for generating logits processor fsm
global_thread_pool
=
None
# used for generating logits processor fsm
...
@@ -57,9 +88,6 @@ def _get_guide_and_mode(
...
@@ -57,9 +88,6 @@ def _get_guide_and_mode(
)
->
Tuple
[
str
,
GuidedDecodingMode
]:
)
->
Tuple
[
str
,
GuidedDecodingMode
]:
if
request
.
guided_json
:
if
request
.
guided_json
:
if
not
isinstance
(
request
.
guided_json
,
(
str
,
dict
,
BaseModel
)):
raise
TypeError
(
"JSON schema must be str, dict, or BaseModel"
)
json
=
request
.
guided_json
json
=
request
.
guided_json
if
isinstance
(
json
,
dict
):
if
isinstance
(
json
,
dict
):
# turn dict into hashable string
# turn dict into hashable string
...
@@ -69,33 +97,33 @@ def _get_guide_and_mode(
...
@@ -69,33 +97,33 @@ def _get_guide_and_mode(
# with the same fields will get hashed the same
# with the same fields will get hashed the same
json
=
str
(
json
.
__signature__
)
json
=
str
(
json
.
__signature__
)
return
json
,
GuidedDecodingMode
.
JSON
return
json
,
GuidedDecodingMode
.
JSON
elif
request
.
guided_regex
:
elif
request
.
guided_regex
:
if
not
isinstance
(
request
.
guided_regex
,
str
):
raise
TypeError
(
"Regex must be string"
)
return
request
.
guided_regex
,
GuidedDecodingMode
.
REGEX
return
request
.
guided_regex
,
GuidedDecodingMode
.
REGEX
elif
request
.
guided_choice
:
elif
request
.
guided_choice
:
if
not
isinstance
(
request
.
guided_choice
,
list
):
raise
TypeError
(
"Choices must be a list"
)
# choice just uses regex
# choice just uses regex
choices
=
[
choices
=
[
regex_escape
(
str
(
choice
))
for
choice
in
request
.
guided_choice
regex_escape
(
str
(
choice
))
for
choice
in
request
.
guided_choice
]
]
choices_regex
=
"("
+
"|"
.
join
(
choices
)
+
")"
choices_regex
=
"("
+
"|"
.
join
(
choices
)
+
")"
return
choices_regex
,
GuidedDecodingMode
.
CHOICE
return
choices_regex
,
GuidedDecodingMode
.
CHOICE
elif
request
.
guided_grammar
:
return
request
.
guided_grammar
,
GuidedDecodingMode
.
GRAMMAR
elif
(
request
.
response_format
is
not
None
and
request
.
response_format
.
type
==
"json_object"
):
return
JSON_GRAMMAR
,
GuidedDecodingMode
.
GRAMMAR
else
:
else
:
return
None
,
None
return
None
,
None
@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str,
                                 tokenizer: PreTrainedTokenizerBase,
                                 mode: GuidedDecodingMode):
    """Return the logits processor matching ``mode``, built from ``guide``.

    Calls are memoized by ``lru_cache`` (at most 32 entries), keyed on the
    ``(guide, tokenizer, mode)`` arguments.

    Parameters
    ----------
    guide
        The guide string handed to the processor constructor
    tokenizer
        The model's tokenizer
    mode
        Which kind of guided decoding to build a processor for

    Raises
    ------
    ValueError
        If ``mode`` is not one of the handled ``GuidedDecodingMode`` members
    """
    if mode == GuidedDecodingMode.JSON:
        return JSONLogitsProcessor(guide, tokenizer)
    # CHOICE guides are passed to the regex processor, same as REGEX guides.
    if mode in (GuidedDecodingMode.REGEX, GuidedDecodingMode.CHOICE):
        return RegexLogitsProcessor(guide, tokenizer)
    if mode == GuidedDecodingMode.GRAMMAR:
        return CFGLogitsProcessor(guide, tokenizer)
    raise ValueError(f"Unknown guided decoding mode {mode}")
vllm/model_executor/guided_logits_processors.py
View file @
120157fd
...
@@ -16,30 +16,60 @@
...
@@ -16,30 +16,60 @@
import
json
import
json
import
math
import
math
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Union
,
DefaultDict
,
Dict
,
List
,
Optional
from
typing
import
Union
,
DefaultDict
,
Dict
,
List
,
Optional
,
Callable
import
torch
import
torch
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
outlines.fsm.fsm
import
RegexFSM
from
transformers
import
PreTrainedTokenizerBase
from
outlines.fsm.fsm
import
RegexFSM
,
CFGFSM
from
outlines.fsm.json_schema
import
build_regex_from_schema
from
outlines.fsm.json_schema
import
build_regex_from_schema
class
Regex
LogitsProcessor
:
class
Base
LogitsProcessor
:
def
__init__
(
self
,
regex_string
:
str
,
t
okenizer
):
def
adapt_tokenizer
(
self
,
tokenizer
:
PreTrainedT
okenizer
Base
):
"""
Compile the FSM that drives the regex-structured generation
.
"""
Adapt vLLM's tokenizer to use to compile the FSM
.
Parameters
The API of Outlines tokenizers is slightly different to that of
----------
`transformers`. The decoder of outlines, returns a list whereas
regex_string
the decode of vLLM returns an str. To sync the vLLM decoder with
A string that represents a regular express
ion
outlines internal api, the decoder should be adapted. In addit
ion
tokenizer
we need to handle the missing spaces to Llama's
tokenizer
to be
The model's tokenizer
able to compile FSMs for this model.
"""
"""
tokenizer
=
self
.
adapt_tokenizer
(
tokenizer
)
if
getattr
(
tokenizer
,
"_outlines_adapted"
,
False
):
fsm
=
RegexFSM
(
regex_string
,
tokenizer
)
return
tokenizer
self
.
fsm
=
fsm
tokenizer
.
vocabulary
=
tokenizer
.
get_vocab
()
tokenizer
.
special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
def
convert_token_to_string
(
token
:
str
)
->
str
:
from
transformers.file_utils
import
SPIECE_UNDERLINE
string
=
tokenizer
.
convert_tokens_to_string
([
token
])
# A hack to handle missing spaces to HF's Llama tokenizers
if
token
.
startswith
(
SPIECE_UNDERLINE
)
or
token
==
"<0x20>"
:
return
" "
+
string
return
string
def
change_decoder
(
decoder
:
Callable
[[
List
[
int
]],
str
]
)
->
Callable
[[
List
[
int
]],
List
[
str
]]:
"""Sync vLLM's decoder with the outlines by returning list."""
def
new_decoder
(
inp_tokens
:
List
[
int
])
->
List
[
str
]:
return
[
decoder
(
inp_tokens
)]
return
new_decoder
tokenizer
.
convert_token_to_string
=
convert_token_to_string
tokenizer
.
decode
=
change_decoder
(
tokenizer
.
decode
)
setattr
(
tokenizer
,
"_outlines_adapted"
,
True
)
# noqa: B010
return
tokenizer
def
init_state
(
self
):
def
init_state
(
self
):
"""Initialize the FSM states."""
"""Initialize the FSM states."""
...
@@ -69,38 +99,30 @@ class RegexLogitsProcessor:
...
@@ -69,38 +99,30 @@ class RegexLogitsProcessor:
return
scores
return
scores
def
adapt_tokenizer
(
self
,
tokenizer
):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. In addition we need to handle the missing spaces to
Llama's tokenizer to be able to compile FSMs for this model.
"""
tokenizer
.
vocabulary
=
tokenizer
.
get_vocab
()
tokenizer
.
special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
def
convert_token_to_string
(
token
:
str
)
->
str
:
from
transformers.file_utils
import
SPIECE_UNDERLINE
string
=
tokenizer
.
convert_tokens_to_string
([
token
])
class
RegexLogitsProcessor
(
BaseLogitsProcessor
):
# A hack to handle missing spaces to HF's Llama tokenizers
def
__init__
(
self
,
regex_string
:
str
,
tokenizer
:
PreTrainedTokenizerBase
):
if
token
.
startswith
(
SPIECE_UNDERLINE
)
or
token
==
"<0x20>"
:
"""Compile the FSM that drives the regex-structured generation.
return
" "
+
string
return
string
tokenizer
.
convert_token_to_string
=
convert_token_to_string
Parameters
----------
regex_string
A string that represents a regular expression
tokenizer
The model's tokenizer
return
tokenizer
"""
tokenizer
=
self
.
adapt_tokenizer
(
tokenizer
)
fsm
=
RegexFSM
(
regex_string
,
tokenizer
)
self
.
fsm
=
fsm
class
JSONLogitsProcessor
(
RegexLogitsProcessor
):
class
JSONLogitsProcessor
(
RegexLogitsProcessor
):
def
__init__
(
self
,
def
__init__
(
self
,
schema
:
Union
[
str
,
Dict
,
BaseModel
],
schema
:
Union
[
str
,
Dict
,
BaseModel
],
tokenizer
,
tokenizer
:
PreTrainedTokenizerBase
,
whitespace_pattern
:
Optional
[
str
]
=
None
):
whitespace_pattern
:
Optional
[
str
]
=
None
):
"""Compile the FSM that drives the JSON-guided generation.
"""Compile the FSM that drives the JSON-guided generation.
...
@@ -130,3 +152,21 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
...
@@ -130,3 +152,21 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
f
"the JSON Schema specification"
)
f
"the JSON Schema specification"
)
regex_string
=
build_regex_from_schema
(
schema_str
,
whitespace_pattern
)
regex_string
=
build_regex_from_schema
(
schema_str
,
whitespace_pattern
)
super
().
__init__
(
regex_string
,
tokenizer
)
super
().
__init__
(
regex_string
,
tokenizer
)
class CFGLogitsProcessor(BaseLogitsProcessor):

    def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
        """Build the FSM used for context-free-grammar guided generation.

        Parameters
        ----------
        cfg
            The context-free grammar, given as a string
        tokenizer
            The model's tokenizer

        """
        # CFGFSM receives the tokenizer returned by ``adapt_tokenizer``,
        # not the raw argument.
        adapted_tokenizer = self.adapt_tokenizer(tokenizer)
        self.fsm = CFGFSM(cfg, adapted_tokenizer)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment