Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c3845d82
Unverified
Commit
c3845d82
authored
May 01, 2024
by
Robert Caulk
Committed by
GitHub
Apr 30, 2024
Browse files
Allow user to define whitespace pattern for outlines (#4305)
parent
a822eb34
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
21 additions
and
8 deletions
+21
-8
tests/entrypoints/test_guided_processors.py
tests/entrypoints/test_guided_processors.py
+3
-1
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+10
-0
vllm/model_executor/guided_decoding/outlines_decoding.py
vllm/model_executor/guided_decoding/outlines_decoding.py
+5
-3
vllm/model_executor/guided_decoding/outlines_logits_processors.py
...el_executor/guided_decoding/outlines_logits_processors.py
+3
-4
No files found.
tests/entrypoints/test_guided_processors.py
View file @
c3845d82
...
@@ -57,7 +57,9 @@ def test_guided_logits_processors():
...
@@ -57,7 +57,9 @@ def test_guided_logits_processors():
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
tokenizer
=
AutoTokenizer
.
from_pretrained
(
'HuggingFaceH4/zephyr-7b-beta'
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
'HuggingFaceH4/zephyr-7b-beta'
)
regex_LP
=
RegexLogitsProcessor
(
TEST_REGEX
,
tokenizer
)
regex_LP
=
RegexLogitsProcessor
(
TEST_REGEX
,
tokenizer
)
json_LP
=
JSONLogitsProcessor
(
TEST_SCHEMA
,
tokenizer
)
json_LP
=
JSONLogitsProcessor
(
TEST_SCHEMA
,
tokenizer
,
whitespace_pattern
=
None
)
regex_LP
.
init_state
()
regex_LP
.
init_state
()
token_ids
=
tokenizer
.
encode
(
token_ids
=
tokenizer
.
encode
(
...
...
vllm/entrypoints/openai/protocol.py
View file @
c3845d82
...
@@ -146,6 +146,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -146,6 +146,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
"If specified, will override the default guided decoding backend "
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"
))
"'outlines' / 'lm-format-enforcer'"
))
guided_whitespace_pattern
:
Optional
[
str
]
=
Field
(
default
=
None
,
description
=
(
"If specified, will override the default whitespace pattern "
"for guided json decoding."
))
# doc: end-chat-completion-extra-params
# doc: end-chat-completion-extra-params
...
@@ -285,6 +290,11 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -285,6 +290,11 @@ class CompletionRequest(OpenAIBaseModel):
"If specified, will override the default guided decoding backend "
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"
))
"'outlines' / 'lm-format-enforcer'"
))
guided_whitespace_pattern
:
Optional
[
str
]
=
Field
(
default
=
None
,
description
=
(
"If specified, will override the default whitespace pattern "
"for guided json decoding."
))
# doc: end-completion-extra-params
# doc: end-completion-extra-params
...
...
vllm/model_executor/guided_decoding/outlines_decoding.py
View file @
c3845d82
...
@@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor(
...
@@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor(
result
=
await
loop
.
run_in_executor
(
global_thread_pool
,
result
=
await
loop
.
run_in_executor
(
global_thread_pool
,
_get_cached_logits_processor
,
guide
,
_get_cached_logits_processor
,
guide
,
tokenizer
,
mode
)
tokenizer
,
mode
,
request
.
guided_whitespace_pattern
)
logits_processor
=
copy
(
result
)
logits_processor
=
copy
(
result
)
# reset logits processor's internal state
# reset logits processor's internal state
...
@@ -117,9 +118,10 @@ def _get_guide_and_mode(
...
@@ -117,9 +118,10 @@ def _get_guide_and_mode(
@
lru_cache
(
maxsize
=
32
)
@
lru_cache
(
maxsize
=
32
)
def
_get_cached_logits_processor
(
guide
:
str
,
def
_get_cached_logits_processor
(
guide
:
str
,
tokenizer
:
PreTrainedTokenizerBase
,
tokenizer
:
PreTrainedTokenizerBase
,
mode
:
GuidedDecodingMode
):
mode
:
GuidedDecodingMode
,
whitespace_pattern
:
Union
[
str
,
None
]):
if
mode
==
GuidedDecodingMode
.
JSON
:
if
mode
==
GuidedDecodingMode
.
JSON
:
return
JSONLogitsProcessor
(
guide
,
tokenizer
)
return
JSONLogitsProcessor
(
guide
,
tokenizer
,
whitespace_pattern
)
elif
mode
==
GuidedDecodingMode
.
REGEX
or
mode
==
GuidedDecodingMode
.
CHOICE
:
elif
mode
==
GuidedDecodingMode
.
REGEX
or
mode
==
GuidedDecodingMode
.
CHOICE
:
return
RegexLogitsProcessor
(
guide
,
tokenizer
)
return
RegexLogitsProcessor
(
guide
,
tokenizer
)
elif
mode
==
GuidedDecodingMode
.
GRAMMAR
:
elif
mode
==
GuidedDecodingMode
.
GRAMMAR
:
...
...
vllm/model_executor/guided_decoding/outlines_logits_processors.py
View file @
c3845d82
...
@@ -18,7 +18,7 @@ import json
...
@@ -18,7 +18,7 @@ import json
import
math
import
math
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Callable
,
DefaultDict
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Callable
,
DefaultDict
,
Dict
,
List
,
Union
import
torch
import
torch
from
outlines.fsm.fsm
import
CFGFSM
,
FSM
,
RegexFSM
from
outlines.fsm.fsm
import
CFGFSM
,
FSM
,
RegexFSM
...
@@ -80,10 +80,9 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
...
@@ -80,10 +80,9 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
class
JSONLogitsProcessor
(
RegexLogitsProcessor
):
class
JSONLogitsProcessor
(
RegexLogitsProcessor
):
def
__init__
(
self
,
def
__init__
(
self
,
schema
:
Union
[
str
,
Dict
,
BaseModel
],
schema
:
Union
[
str
,
Dict
,
BaseModel
],
tokenizer
:
PreTrainedTokenizerBase
,
tokenizer
:
PreTrainedTokenizerBase
,
whitespace_pattern
:
Opt
ion
al
[
str
]
=
None
):
whitespace_pattern
:
Un
ion
[
str
,
None
]
):
"""Compile the FSM that drives the JSON-guided generation.
"""Compile the FSM that drives the JSON-guided generation.
Parameters
Parameters
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment