Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
120157fd
Unverified
Commit
120157fd
authored
Mar 16, 2024
by
Simon Mo
Committed by
GitHub
Mar 16, 2024
Browse files
Support arbitrary json_object in OpenAI and Context Free Grammar (#3211)
parent
8e67598a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
176 additions
and
49 deletions
+176
-49
tests/entrypoints/test_openai_server.py
tests/entrypoints/test_openai_server.py
+50
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+9
-0
vllm/model_executor/guided_decoding.py
vllm/model_executor/guided_decoding.py
+41
-13
vllm/model_executor/guided_logits_processors.py
vllm/model_executor/guided_logits_processors.py
+76
-36
No files found.
tests/entrypoints/test_openai_server.py
View file @
120157fd
...
...
@@ -660,5 +660,55 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
,
guided_json
=
TEST_SCHEMA
))
async
def
test_response_format_json_object
(
server
,
client
:
openai
.
AsyncOpenAI
):
resp
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
[{
"role"
:
"user"
,
"content"
:
(
'what is 1+1? please respond with a JSON object, '
'the format is {"result": 2}'
)
}],
response_format
=
{
"type"
:
"json_object"
})
content
=
resp
.
choices
[
0
].
message
.
content
loaded
=
json
.
loads
(
content
)
assert
loaded
==
{
"result"
:
2
},
loaded
async
def
test_guided_grammar
(
server
,
client
:
openai
.
AsyncOpenAI
):
simple_sql_grammar
=
"""
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
"""
completion
=
await
client
.
completions
.
create
(
model
=
MODEL_NAME
,
prompt
=
(
"Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"
),
temperature
=
1.0
,
max_tokens
=
500
,
extra_body
=
dict
(
guided_grammar
=
simple_sql_grammar
))
content
=
completion
.
choices
[
0
].
text
# use Lark to parse the output, and make sure it's a valid parse tree
from
lark
import
Lark
parser
=
Lark
(
simple_sql_grammar
)
parser
.
parse
(
content
)
# remove spaces for comparison b/c we removed them in the grammar
ground_truth
=
"SELECT col_1 from table_1 where col_1 = 1"
.
replace
(
" "
,
""
)
assert
content
.
strip
()
==
ground_truth
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
vllm/entrypoints/openai/protocol.py
View file @
120157fd
...
...
@@ -55,6 +55,11 @@ class UsageInfo(BaseModel):
completion_tokens
:
Optional
[
int
]
=
0
class
ResponseFormat
(
BaseModel
):
# type must be "json_object" or "text"
type
:
str
=
Literal
[
"text"
,
"json_object"
]
class
ChatCompletionRequest
(
BaseModel
):
model
:
str
messages
:
List
[
Dict
[
str
,
str
]]
...
...
@@ -89,6 +94,8 @@ class ChatCompletionRequest(BaseModel):
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
None
guided_regex
:
Optional
[
str
]
=
None
guided_choice
:
Optional
[
List
[
str
]]
=
None
guided_grammar
:
Optional
[
str
]
=
None
response_format
:
Optional
[
ResponseFormat
]
=
None
def
to_sampling_params
(
self
)
->
SamplingParams
:
if
self
.
logprobs
and
not
self
.
top_logprobs
:
...
...
@@ -183,6 +190,8 @@ class CompletionRequest(BaseModel):
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
None
guided_regex
:
Optional
[
str
]
=
None
guided_choice
:
Optional
[
List
[
str
]]
=
None
guided_grammar
:
Optional
[
str
]
=
None
response_format
:
Optional
[
ResponseFormat
]
=
None
def
to_sampling_params
(
self
):
echo_without_generation
=
self
.
echo
and
self
.
max_tokens
==
0
...
...
vllm/model_executor/guided_decoding.py
View file @
120157fd
...
...
@@ -6,19 +6,50 @@ from functools import lru_cache
from
json
import
dumps
as
json_dumps
from
re
import
escape
as
regex_escape
from
typing
import
Union
,
Tuple
from
pydantic
import
BaseModel
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
CompletionRequest
,
ChatCompletionRequest
)
from
vllm.model_executor.guided_logits_processors
import
(
JSONLogitsProcessor
,
RegexLogitsProcessor
)
RegexLogitsProcessor
,
CFGLogitsProcessor
)
class
GuidedDecodingMode
(
Enum
):
JSON
=
"json"
REGEX
=
"regex"
CHOICE
=
"choice"
GRAMMAR
=
"grammar"
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# the main difference is that we changed the start: value to
# start: object | array, so we are denying scalar values as the root of the
# JSON. Starting with scalars as the root seems to cause llama to generate
# without stop.
JSON_GRAMMAR
=
r
"""
?start: object | array
?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null
array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : UNESCAPED_STRING ":" value
%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
"""
global_thread_pool
=
None
# used for generating logits processor fsm
...
...
@@ -57,9 +88,6 @@ def _get_guide_and_mode(
)
->
Tuple
[
str
,
GuidedDecodingMode
]:
if
request
.
guided_json
:
if
not
isinstance
(
request
.
guided_json
,
(
str
,
dict
,
BaseModel
)):
raise
TypeError
(
"JSON schema must be str, dict, or BaseModel"
)
json
=
request
.
guided_json
if
isinstance
(
json
,
dict
):
# turn dict into hashable string
...
...
@@ -69,33 +97,33 @@ def _get_guide_and_mode(
# with the same fields will get hashed the same
json
=
str
(
json
.
__signature__
)
return
json
,
GuidedDecodingMode
.
JSON
elif
request
.
guided_regex
:
if
not
isinstance
(
request
.
guided_regex
,
str
):
raise
TypeError
(
"Regex must be string"
)
return
request
.
guided_regex
,
GuidedDecodingMode
.
REGEX
elif
request
.
guided_choice
:
if
not
isinstance
(
request
.
guided_choice
,
list
):
raise
TypeError
(
"Choices must be a list"
)
# choice just uses regex
choices
=
[
regex_escape
(
str
(
choice
))
for
choice
in
request
.
guided_choice
]
choices_regex
=
"("
+
"|"
.
join
(
choices
)
+
")"
return
choices_regex
,
GuidedDecodingMode
.
CHOICE
elif
request
.
guided_grammar
:
return
request
.
guided_grammar
,
GuidedDecodingMode
.
GRAMMAR
elif
(
request
.
response_format
is
not
None
and
request
.
response_format
.
type
==
"json_object"
):
return
JSON_GRAMMAR
,
GuidedDecodingMode
.
GRAMMAR
else
:
return
None
,
None
@
lru_cache
(
maxsize
=
32
)
def
_get_cached_logits_processor
(
guide
:
str
,
tokenizer
,
def
_get_cached_logits_processor
(
guide
:
str
,
tokenizer
:
PreTrainedTokenizerBase
,
mode
:
GuidedDecodingMode
):
if
mode
==
GuidedDecodingMode
.
JSON
:
return
JSONLogitsProcessor
(
guide
,
tokenizer
)
elif
mode
==
GuidedDecodingMode
.
REGEX
or
mode
==
GuidedDecodingMode
.
CHOICE
:
return
RegexLogitsProcessor
(
guide
,
tokenizer
)
elif
mode
==
GuidedDecodingMode
.
GRAMMAR
:
return
CFGLogitsProcessor
(
guide
,
tokenizer
)
else
:
raise
ValueError
(
f
"Unknown guided decoding mode
{
mode
}
"
)
vllm/model_executor/guided_logits_processors.py
View file @
120157fd
...
...
@@ -16,30 +16,60 @@
import
json
import
math
from
collections
import
defaultdict
from
typing
import
Union
,
DefaultDict
,
Dict
,
List
,
Optional
from
typing
import
Union
,
DefaultDict
,
Dict
,
List
,
Optional
,
Callable
import
torch
from
pydantic
import
BaseModel
from
outlines.fsm.fsm
import
RegexFSM
from
transformers
import
PreTrainedTokenizerBase
from
outlines.fsm.fsm
import
RegexFSM
,
CFGFSM
from
outlines.fsm.json_schema
import
build_regex_from_schema
class
Regex
LogitsProcessor
:
class
Base
LogitsProcessor
:
def
__init__
(
self
,
regex_string
:
str
,
t
okenizer
):
"""
Compile the FSM that drives the regex-structured generation
.
def
adapt_tokenizer
(
self
,
tokenizer
:
PreTrainedT
okenizer
Base
):
"""
Adapt vLLM's tokenizer to use to compile the FSM
.
Parameters
----------
regex_string
A string that represents a regular express
ion
tokenizer
The model's tokenizer
The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addit
ion
we need to handle the missing spaces to Llama's
tokenizer
to be
able to compile FSMs for this model.
"""
tokenizer
=
self
.
adapt_tokenizer
(
tokenizer
)
fsm
=
RegexFSM
(
regex_string
,
tokenizer
)
self
.
fsm
=
fsm
if
getattr
(
tokenizer
,
"_outlines_adapted"
,
False
):
return
tokenizer
tokenizer
.
vocabulary
=
tokenizer
.
get_vocab
()
tokenizer
.
special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
def
convert_token_to_string
(
token
:
str
)
->
str
:
from
transformers.file_utils
import
SPIECE_UNDERLINE
string
=
tokenizer
.
convert_tokens_to_string
([
token
])
# A hack to handle missing spaces to HF's Llama tokenizers
if
token
.
startswith
(
SPIECE_UNDERLINE
)
or
token
==
"<0x20>"
:
return
" "
+
string
return
string
def
change_decoder
(
decoder
:
Callable
[[
List
[
int
]],
str
]
)
->
Callable
[[
List
[
int
]],
List
[
str
]]:
"""Sync vLLM's decoder with the outlines by returning list."""
def
new_decoder
(
inp_tokens
:
List
[
int
])
->
List
[
str
]:
return
[
decoder
(
inp_tokens
)]
return
new_decoder
tokenizer
.
convert_token_to_string
=
convert_token_to_string
tokenizer
.
decode
=
change_decoder
(
tokenizer
.
decode
)
setattr
(
tokenizer
,
"_outlines_adapted"
,
True
)
# noqa: B010
return
tokenizer
def
init_state
(
self
):
"""Initialize the FSM states."""
...
...
@@ -69,38 +99,30 @@ class RegexLogitsProcessor:
return
scores
def
adapt_tokenizer
(
self
,
tokenizer
):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. In addition we need to handle the missing spaces to
Llama's tokenizer to be able to compile FSMs for this model.
class
RegexLogitsProcessor
(
BaseLogitsProcessor
):
"""
tokenizer
.
vocabulary
=
tokenizer
.
get_vocab
()
tokenizer
.
special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
def
convert_token_to_string
(
token
:
str
)
->
str
:
from
transformers.file_utils
import
SPIECE_UNDERLINE
string
=
tokenizer
.
convert_tokens_to_string
([
token
])
# A hack to handle missing spaces to HF's Llama tokenizers
if
token
.
startswith
(
SPIECE_UNDERLINE
)
or
token
==
"<0x20>"
:
return
" "
+
string
return
string
def
__init__
(
self
,
regex_string
:
str
,
tokenizer
:
PreTrainedTokenizerBase
):
"""Compile the FSM that drives the regex-structured generation.
tokenizer
.
convert_token_to_string
=
convert_token_to_string
Parameters
----------
regex_string
A string that represents a regular expression
tokenizer
The model's tokenizer
return
tokenizer
"""
tokenizer
=
self
.
adapt_tokenizer
(
tokenizer
)
fsm
=
RegexFSM
(
regex_string
,
tokenizer
)
self
.
fsm
=
fsm
class
JSONLogitsProcessor
(
RegexLogitsProcessor
):
def
__init__
(
self
,
schema
:
Union
[
str
,
Dict
,
BaseModel
],
tokenizer
,
tokenizer
:
PreTrainedTokenizerBase
,
whitespace_pattern
:
Optional
[
str
]
=
None
):
"""Compile the FSM that drives the JSON-guided generation.
...
...
@@ -130,3 +152,21 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
f
"the JSON Schema specification"
)
regex_string
=
build_regex_from_schema
(
schema_str
,
whitespace_pattern
)
super
().
__init__
(
regex_string
,
tokenizer
)
class
CFGLogitsProcessor
(
BaseLogitsProcessor
):
def
__init__
(
self
,
cfg
:
str
,
tokenizer
:
PreTrainedTokenizerBase
):
"""Compile the FSM that drives the context free grammar generation.
Parameters
----------
cfg
A string that represents a context-free grammar
tokenizer
The model's tokenizer
"""
tokenizer
=
self
.
adapt_tokenizer
(
tokenizer
)
fsm
=
CFGFSM
(
cfg
,
tokenizer
)
self
.
fsm
=
fsm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment