Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
05434764
Unverified
Commit
05434764
authored
Apr 16, 2024
by
Noam Gat
Committed by
GitHub
Apr 16, 2024
Browse files
LM Format Enforcer Guided Decoding Support (#3868)
Co-authored-by:
Simon Mo
<
simon.mo@hey.com
>
parent
4e7ee664
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
304 additions
and
87 deletions
+304
-87
requirements-common.txt
requirements-common.txt
+1
-0
tests/entrypoints/test_guided_processors.py
tests/entrypoints/test_guided_processors.py
+39
-3
tests/entrypoints/test_openai_server.py
tests/entrypoints/test_openai_server.py
+50
-19
vllm/config.py
vllm/config.py
+21
-5
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+15
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+7
-3
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+12
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+5
-1
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+5
-1
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+25
-0
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
...l_executor/guided_decoding/lm_format_enforcer_decoding.py
+69
-0
vllm/model_executor/guided_decoding/outlines_decoding.py
vllm/model_executor/guided_decoding/outlines_decoding.py
+3
-4
vllm/model_executor/guided_decoding/outlines_logits_processors.py
...el_executor/guided_decoding/outlines_logits_processors.py
+52
-48
No files found.
requirements-common.txt
View file @
05434764
...
@@ -11,6 +11,7 @@ uvicorn[standard]
...
@@ -11,6 +11,7 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
tiktoken == 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.9.3
outlines == 0.0.34 # Requires torch >= 2.1.0
outlines == 0.0.34 # Requires torch >= 2.1.0
typing_extensions
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
tests/entrypoints/test_guided_processors.py
View file @
05434764
# This unit test should be moved to a new
# This unit test should be moved to a new
# tests/test_guided_decoding directory.
# tests/test_guided_decoding directory.
import
pytest
import
torch
import
torch
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
vllm.model_executor.guided_logits_processors
import
(
JSONLogitsProcessor
,
from
vllm.entrypoints.openai.protocol
import
CompletionRequest
RegexLogitsProcessor
)
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.model_executor.guided_decoding.outlines_logits_processors
import
(
JSONLogitsProcessor
,
RegexLogitsProcessor
)
TEST_SCHEMA
=
{
TEST_SCHEMA
=
{
"type"
:
"object"
,
"type"
:
"object"
,
...
@@ -73,3 +76,36 @@ def test_guided_logits_processors():
...
@@ -73,3 +76,36 @@ def test_guided_logits_processors():
json_LP
(
token_ids
,
tensor
)
json_LP
(
token_ids
,
tensor
)
assert
tensor
.
shape
==
original_tensor
.
shape
assert
tensor
.
shape
==
original_tensor
.
shape
assert
not
torch
.
allclose
(
tensor
,
original_tensor
)
assert
not
torch
.
allclose
(
tensor
,
original_tensor
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_logits_processor_black_box
(
backend
:
str
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
'HuggingFaceH4/zephyr-7b-beta'
)
token_ids
=
tokenizer
.
encode
(
f
"Give an example IPv4 address with this regex:
{
TEST_REGEX
}
"
)
regex_request
=
CompletionRequest
(
model
=
'test'
,
prompt
=
token_ids
,
guided_regex
=
TEST_REGEX
)
regex_lp
=
await
get_guided_decoding_logits_processor
(
backend
,
regex_request
,
tokenizer
)
assert
regex_lp
is
not
None
tensor
=
torch
.
rand
(
32000
)
original_tensor
=
torch
.
clone
(
tensor
)
tensor
=
regex_lp
(
token_ids
,
tensor
)
assert
tensor
.
shape
==
original_tensor
.
shape
assert
not
torch
.
allclose
(
tensor
,
original_tensor
)
token_ids
=
tokenizer
.
encode
(
f
"Give an employee profile that fits this schema:
{
TEST_SCHEMA
}
"
)
json_request
=
CompletionRequest
(
model
=
'test'
,
prompt
=
token_ids
,
guided_json
=
TEST_SCHEMA
)
json_lp
=
await
get_guided_decoding_logits_processor
(
backend
,
json_request
,
tokenizer
)
assert
json_lp
is
not
None
tensor
=
torch
.
rand
(
32000
)
original_tensor
=
torch
.
clone
(
tensor
)
tensor
=
json_lp
(
token_ids
,
tensor
)
assert
tensor
.
shape
==
original_tensor
.
shape
assert
not
torch
.
allclose
(
tensor
,
original_tensor
)
tests/entrypoints/test_openai_server.py
View file @
05434764
...
@@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
...
@@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
assert
first_response
!=
completion
.
choices
[
0
].
text
assert
first_response
!=
completion
.
choices
[
0
].
text
async
def
test_guided_json_completion
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_json_completion
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
completion
=
await
client
.
completions
.
create
(
completion
=
await
client
.
completions
.
create
(
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
prompt
=
f
"Give an example JSON for an employee profile "
prompt
=
f
"Give an example JSON for an employee profile "
...
@@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
...
@@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
n
=
3
,
n
=
3
,
temperature
=
1.0
,
temperature
=
1.0
,
max_tokens
=
500
,
max_tokens
=
500
,
extra_body
=
dict
(
guided_json
=
TEST_SCHEMA
))
extra_body
=
dict
(
guided_json
=
TEST_SCHEMA
,
guided_decoding_backend
=
guided_decoding_backend
))
assert
completion
.
id
is
not
None
assert
completion
.
id
is
not
None
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
3
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
3
...
@@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
...
@@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
TEST_SCHEMA
)
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
TEST_SCHEMA
)
async
def
test_guided_json_chat
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_json_chat
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
messages
=
[{
messages
=
[{
"role"
:
"system"
,
"role"
:
"system"
,
"content"
:
"you are a helpful assistant"
"content"
:
"you are a helpful assistant"
...
@@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
...
@@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion
=
await
client
.
chat
.
completions
.
create
(
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
messages
=
messages
,
messages
=
messages
,
max_tokens
=
500
,
max_tokens
=
1000
,
extra_body
=
dict
(
guided_json
=
TEST_SCHEMA
))
extra_body
=
dict
(
guided_json
=
TEST_SCHEMA
,
guided_decoding_backend
=
guided_decoding_backend
))
message
=
chat_completion
.
choices
[
0
].
message
message
=
chat_completion
.
choices
[
0
].
message
assert
message
.
content
is
not
None
assert
message
.
content
is
not
None
json1
=
json
.
loads
(
message
.
content
)
json1
=
json
.
loads
(
message
.
content
)
...
@@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
...
@@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion
=
await
client
.
chat
.
completions
.
create
(
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
messages
=
messages
,
messages
=
messages
,
max_tokens
=
500
,
max_tokens
=
1000
,
extra_body
=
dict
(
guided_json
=
TEST_SCHEMA
))
extra_body
=
dict
(
guided_json
=
TEST_SCHEMA
,
guided_decoding_backend
=
guided_decoding_backend
))
message
=
chat_completion
.
choices
[
0
].
message
message
=
chat_completion
.
choices
[
0
].
message
assert
message
.
content
is
not
None
assert
message
.
content
is
not
None
json2
=
json
.
loads
(
message
.
content
)
json2
=
json
.
loads
(
message
.
content
)
...
@@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
...
@@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
assert
json1
[
"age"
]
!=
json2
[
"age"
]
assert
json1
[
"age"
]
!=
json2
[
"age"
]
async
def
test_guided_regex_completion
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_regex_completion
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
completion
=
await
client
.
completions
.
create
(
completion
=
await
client
.
completions
.
create
(
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
prompt
=
f
"Give an example IPv4 address with this regex:
{
TEST_REGEX
}
"
,
prompt
=
f
"Give an example IPv4 address with this regex:
{
TEST_REGEX
}
"
,
n
=
3
,
n
=
3
,
temperature
=
1.0
,
temperature
=
1.0
,
max_tokens
=
20
,
max_tokens
=
20
,
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
))
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
,
guided_decoding_backend
=
guided_decoding_backend
))
assert
completion
.
id
is
not
None
assert
completion
.
id
is
not
None
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
3
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
3
...
@@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
...
@@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
assert
re
.
fullmatch
(
TEST_REGEX
,
completion
.
choices
[
i
].
text
)
is
not
None
assert
re
.
fullmatch
(
TEST_REGEX
,
completion
.
choices
[
i
].
text
)
is
not
None
async
def
test_guided_regex_chat
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_regex_chat
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
messages
=
[{
messages
=
[{
"role"
:
"system"
,
"role"
:
"system"
,
"content"
:
"you are a helpful assistant"
"content"
:
"you are a helpful assistant"
...
@@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
...
@@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
messages
=
messages
,
messages
=
messages
,
max_tokens
=
20
,
max_tokens
=
20
,
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
))
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
,
guided_decoding_backend
=
guided_decoding_backend
))
ip1
=
chat_completion
.
choices
[
0
].
message
.
content
ip1
=
chat_completion
.
choices
[
0
].
message
.
content
assert
ip1
is
not
None
assert
ip1
is
not
None
assert
re
.
fullmatch
(
TEST_REGEX
,
ip1
)
is
not
None
assert
re
.
fullmatch
(
TEST_REGEX
,
ip1
)
is
not
None
...
@@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
...
@@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
messages
=
messages
,
messages
=
messages
,
max_tokens
=
20
,
max_tokens
=
20
,
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
))
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
,
guided_decoding_backend
=
guided_decoding_backend
))
ip2
=
chat_completion
.
choices
[
0
].
message
.
content
ip2
=
chat_completion
.
choices
[
0
].
message
.
content
assert
ip2
is
not
None
assert
ip2
is
not
None
assert
re
.
fullmatch
(
TEST_REGEX
,
ip2
)
is
not
None
assert
re
.
fullmatch
(
TEST_REGEX
,
ip2
)
is
not
None
assert
ip1
!=
ip2
assert
ip1
!=
ip2
async
def
test_guided_choice_completion
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_choice_completion
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
completion
=
await
client
.
completions
.
create
(
completion
=
await
client
.
completions
.
create
(
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
prompt
=
"The best language for type-safe systems programming is "
,
prompt
=
"The best language for type-safe systems programming is "
,
n
=
2
,
n
=
2
,
temperature
=
1.0
,
temperature
=
1.0
,
max_tokens
=
10
,
max_tokens
=
10
,
extra_body
=
dict
(
guided_choice
=
TEST_CHOICE
))
extra_body
=
dict
(
guided_choice
=
TEST_CHOICE
,
guided_decoding_backend
=
guided_decoding_backend
))
assert
completion
.
id
is
not
None
assert
completion
.
id
is
not
None
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
2
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
2
...
@@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
...
@@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
assert
completion
.
choices
[
i
].
text
in
TEST_CHOICE
assert
completion
.
choices
[
i
].
text
in
TEST_CHOICE
async
def
test_guided_choice_chat
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_choice_chat
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
messages
=
[{
messages
=
[{
"role"
:
"system"
,
"role"
:
"system"
,
"content"
:
"you are a helpful assistant"
"content"
:
"you are a helpful assistant"
...
@@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
...
@@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
messages
=
messages
,
messages
=
messages
,
max_tokens
=
10
,
max_tokens
=
10
,
extra_body
=
dict
(
guided_choice
=
TEST_CHOICE
))
extra_body
=
dict
(
guided_choice
=
TEST_CHOICE
,
guided_decoding_backend
=
guided_decoding_backend
))
choice1
=
chat_completion
.
choices
[
0
].
message
.
content
choice1
=
chat_completion
.
choices
[
0
].
message
.
content
assert
choice1
in
TEST_CHOICE
assert
choice1
in
TEST_CHOICE
...
@@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
...
@@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
messages
=
messages
,
messages
=
messages
,
max_tokens
=
10
,
max_tokens
=
10
,
extra_body
=
dict
(
guided_choice
=
TEST_CHOICE
))
extra_body
=
dict
(
guided_choice
=
TEST_CHOICE
,
guided_decoding_backend
=
guided_decoding_backend
))
choice2
=
chat_completion
.
choices
[
0
].
message
.
content
choice2
=
chat_completion
.
choices
[
0
].
message
.
content
assert
choice2
in
TEST_CHOICE
assert
choice2
in
TEST_CHOICE
assert
choice1
!=
choice2
assert
choice1
!=
choice2
async
def
test_guided_decoding_type_error
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_decoding_type_error
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
with
pytest
.
raises
(
openai
.
BadRequestError
):
with
pytest
.
raises
(
openai
.
BadRequestError
):
_
=
await
client
.
completions
.
create
(
_
=
await
client
.
completions
.
create
(
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
prompt
=
"Give an example JSON that fits this schema: 42"
,
prompt
=
"Give an example JSON that fits this schema: 42"
,
extra_body
=
dict
(
guided_json
=
42
))
extra_body
=
dict
(
guided_json
=
42
,
guided_decoding_backend
=
guided_decoding_backend
))
messages
=
[{
messages
=
[{
"role"
:
"system"
,
"role"
:
"system"
,
...
...
vllm/config.py
View file @
05434764
...
@@ -1079,6 +1079,21 @@ def _get_and_verify_max_len(
...
@@ -1079,6 +1079,21 @@ def _get_and_verify_max_len(
return
int
(
max_model_len
)
return
int
(
max_model_len
)
@
dataclass
class
DecodingConfig
:
"""Dataclass which contains the decoding strategy of the engine"""
# Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
guided_decoding_backend
:
str
=
'outlines'
def
__post_init__
(
self
):
valid_guided_backends
=
[
'outlines'
,
'lm-format-enforcer'
]
backend
=
self
.
guided_decoding_backend
if
backend
not
in
valid_guided_backends
:
raise
ValueError
(
f
"Invalid guided_decoding_backend '
{
backend
}
,"
f
"must be one of
{
valid_guided_backends
}
"
)
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
EngineConfig
:
class
EngineConfig
:
"""Dataclass which contains all engine-related configuration. This
"""Dataclass which contains all engine-related configuration. This
...
@@ -1093,6 +1108,7 @@ class EngineConfig:
...
@@ -1093,6 +1108,7 @@ class EngineConfig:
lora_config
:
Optional
[
LoRAConfig
]
lora_config
:
Optional
[
LoRAConfig
]
vision_language_config
:
Optional
[
VisionLanguageConfig
]
vision_language_config
:
Optional
[
VisionLanguageConfig
]
speculative_config
:
Optional
[
SpeculativeConfig
]
speculative_config
:
Optional
[
SpeculativeConfig
]
decoding_config
:
Optional
[
DecodingConfig
]
tensorizer_config
:
Optional
[
TensorizerConfig
]
tensorizer_config
:
Optional
[
TensorizerConfig
]
def
__post_init__
(
self
):
def
__post_init__
(
self
):
...
...
vllm/engine/arg_utils.py
View file @
05434764
...
@@ -5,9 +5,9 @@ import os
...
@@ -5,9 +5,9 @@ import os
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
BinaryIO
,
Optional
,
Union
from
typing
import
BinaryIO
,
Optional
,
Union
from
vllm.config
import
(
CacheConfig
,
De
viceConfig
,
Engine
Config
,
LoRA
Config
,
from
vllm.config
import
(
CacheConfig
,
De
coding
Config
,
Device
Config
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SpeculativeConfig
,
TensorizerConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TensorizerConfig
,
TokenizerPoolConfig
,
VisionLanguageConfig
)
TokenizerPoolConfig
,
VisionLanguageConfig
)
from
vllm.model_executor.tensorizer_loader
import
TensorizerArgs
from
vllm.model_executor.tensorizer_loader
import
TensorizerArgs
from
vllm.utils
import
str_to_int_tuple
from
vllm.utils
import
str_to_int_tuple
...
@@ -80,6 +80,7 @@ class EngineArgs:
...
@@ -80,6 +80,7 @@ class EngineArgs:
scheduler_delay_factor
:
float
=
0.0
scheduler_delay_factor
:
float
=
0.0
enable_chunked_prefill
:
bool
=
False
enable_chunked_prefill
:
bool
=
False
guided_decoding_backend
:
str
=
'outlines'
# Speculative decoding configuration.
# Speculative decoding configuration.
speculative_model
:
Optional
[
str
]
=
None
speculative_model
:
Optional
[
str
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
...
@@ -200,6 +201,13 @@ class EngineArgs:
...
@@ -200,6 +201,13 @@ class EngineArgs:
default
=
EngineArgs
.
max_model_len
,
default
=
EngineArgs
.
max_model_len
,
help
=
'model context length. If unspecified, '
help
=
'model context length. If unspecified, '
'will be automatically derived from the model.'
)
'will be automatically derived from the model.'
)
parser
.
add_argument
(
'--guided-decoding-backend'
,
type
=
str
,
default
=
'outlines'
,
choices
=
[
'outlines'
,
'lm-format-enforcer'
],
help
=
'Which engine will be used for guided decoding'
' (JSON schema / regex etc)'
)
# Parallel arguments
# Parallel arguments
parser
.
add_argument
(
'--worker-use-ray'
,
parser
.
add_argument
(
'--worker-use-ray'
,
action
=
'store_true'
,
action
=
'store_true'
,
...
@@ -511,6 +519,9 @@ class EngineArgs:
...
@@ -511,6 +519,9 @@ class EngineArgs:
else
:
else
:
vision_language_config
=
None
vision_language_config
=
None
decoding_config
=
DecodingConfig
(
guided_decoding_backend
=
self
.
guided_decoding_backend
)
return
EngineConfig
(
model_config
=
model_config
,
return
EngineConfig
(
model_config
=
model_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
parallel_config
=
parallel_config
,
parallel_config
=
parallel_config
,
...
@@ -519,6 +530,7 @@ class EngineArgs:
...
@@ -519,6 +530,7 @@ class EngineArgs:
lora_config
=
lora_config
,
lora_config
=
lora_config
,
vision_language_config
=
vision_language_config
,
vision_language_config
=
vision_language_config
,
speculative_config
=
speculative_config
,
speculative_config
=
speculative_config
,
decoding_config
=
decoding_config
,
tensorizer_config
=
tensorizer_config
)
tensorizer_config
=
tensorizer_config
)
...
...
vllm/engine/llm_engine.py
View file @
05434764
...
@@ -4,9 +4,10 @@ from typing import Iterable, List, Optional, Tuple, Type, Union
...
@@ -4,9 +4,10 @@ from typing import Iterable, List, Optional, Tuple, Type, Union
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
import
vllm
import
vllm
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
LoRAConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
TensorizerConfig
,
VisionLanguageConfig
)
SpeculativeConfig
,
TensorizerConfig
,
VisionLanguageConfig
)
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics
import
StatLogger
,
Stats
from
vllm.engine.metrics
import
StatLogger
,
Stats
...
@@ -74,6 +75,7 @@ class LLMEngine:
...
@@ -74,6 +75,7 @@ class LLMEngine:
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
decoding_config
:
Optional
[
DecodingConfig
],
tensorizer_config
:
Optional
[
TensorizerConfig
],
tensorizer_config
:
Optional
[
TensorizerConfig
],
executor_class
:
Type
[
ExecutorBase
],
executor_class
:
Type
[
ExecutorBase
],
log_stats
:
bool
,
log_stats
:
bool
,
...
@@ -100,6 +102,7 @@ class LLMEngine:
...
@@ -100,6 +102,7 @@ class LLMEngine:
f
"kv_cache_dtype=
{
cache_config
.
cache_dtype
}
, "
f
"kv_cache_dtype=
{
cache_config
.
cache_dtype
}
, "
f
"quantization_param_path=
{
model_config
.
quantization_param_path
}
, "
f
"quantization_param_path=
{
model_config
.
quantization_param_path
}
, "
f
"device_config=
{
device_config
.
device
}
, "
f
"device_config=
{
device_config
.
device
}
, "
f
"decoding_config=
{
decoding_config
!
r
}
, "
f
"seed=
{
model_config
.
seed
}
)"
)
f
"seed=
{
model_config
.
seed
}
)"
)
# TODO(woosuk): Print more configs in debug mode.
# TODO(woosuk): Print more configs in debug mode.
...
@@ -111,6 +114,7 @@ class LLMEngine:
...
@@ -111,6 +114,7 @@ class LLMEngine:
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
device_config
=
device_config
self
.
speculative_config
=
speculative_config
self
.
speculative_config
=
speculative_config
self
.
decoding_config
=
decoding_config
or
DecodingConfig
()
self
.
tensorizer_config
=
tensorizer_config
self
.
tensorizer_config
=
tensorizer_config
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
...
...
vllm/entrypoints/openai/protocol.py
View file @
05434764
...
@@ -133,6 +133,12 @@ class ChatCompletionRequest(BaseModel):
...
@@ -133,6 +133,12 @@ class ChatCompletionRequest(BaseModel):
description
=
(
description
=
(
"If specified, the output will follow the context free grammar."
),
"If specified, the output will follow the context free grammar."
),
)
)
guided_decoding_backend
:
Optional
[
str
]
=
Field
(
default
=
None
,
description
=
(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"
))
# doc: end-chat-completion-extra-params
# doc: end-chat-completion-extra-params
...
@@ -265,6 +271,12 @@ class CompletionRequest(BaseModel):
...
@@ -265,6 +271,12 @@ class CompletionRequest(BaseModel):
description
=
(
description
=
(
"If specified, the output will follow the context free grammar."
),
"If specified, the output will follow the context free grammar."
),
)
)
guided_decoding_backend
:
Optional
[
str
]
=
Field
(
default
=
None
,
description
=
(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"
))
# doc: end-completion-extra-params
# doc: end-completion-extra-params
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
05434764
...
@@ -68,9 +68,13 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -68,9 +68,13 @@ class OpenAIServingChat(OpenAIServing):
request
,
prompt
=
prompt
)
request
,
prompt
=
prompt
)
sampling_params
=
request
.
to_sampling_params
()
sampling_params
=
request
.
to_sampling_params
()
lora_request
=
self
.
_maybe_get_lora
(
request
)
lora_request
=
self
.
_maybe_get_lora
(
request
)
decoding_config
=
self
.
engine
.
engine
.
decoding_config
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logits_processor
=
(
guided_decode_logits_processor
=
(
await
get_guided_decoding_logits_processor
(
await
get_guided_decoding_logits_processor
(
request
,
await
self
.
engine
.
get_tokenizer
()))
guided_decoding_backend
,
request
,
await
self
.
engine
.
get_tokenizer
()))
if
guided_decode_logits_processor
:
if
guided_decode_logits_processor
:
if
sampling_params
.
logits_processors
is
None
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
sampling_params
.
logits_processors
=
[]
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
05434764
...
@@ -88,9 +88,13 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -88,9 +88,13 @@ class OpenAIServingCompletion(OpenAIServing):
try
:
try
:
sampling_params
=
request
.
to_sampling_params
()
sampling_params
=
request
.
to_sampling_params
()
lora_request
=
self
.
_maybe_get_lora
(
request
)
lora_request
=
self
.
_maybe_get_lora
(
request
)
decoding_config
=
self
.
engine
.
engine
.
decoding_config
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logit_processor
=
(
guided_decode_logit_processor
=
(
await
get_guided_decoding_logits_processor
(
await
get_guided_decoding_logits_processor
(
request
,
await
self
.
engine
.
get_tokenizer
()))
guided_decoding_backend
,
request
,
await
self
.
engine
.
get_tokenizer
()))
if
guided_decode_logit_processor
is
not
None
:
if
guided_decode_logit_processor
is
not
None
:
if
sampling_params
.
logits_processors
is
None
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
sampling_params
.
logits_processors
=
[]
...
...
vllm/model_executor/guided_decoding/__init__.py
0 → 100644
View file @
05434764
from
typing
import
Optional
,
Union
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
)
from
vllm.model_executor.guided_decoding.lm_format_enforcer_decoding
import
(
get_lm_format_enforcer_guided_decoding_logits_processor
)
from
vllm.model_executor.guided_decoding.outlines_decoding
import
(
get_outlines_guided_decoding_logits_processor
)
from
vllm.sampling_params
import
LogitsProcessor
async
def
get_guided_decoding_logits_processor
(
guided_decoding_backend
:
str
,
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
],
tokenizer
)
->
Optional
[
LogitsProcessor
]:
if
guided_decoding_backend
==
'outlines'
:
return
await
get_outlines_guided_decoding_logits_processor
(
request
,
tokenizer
)
if
guided_decoding_backend
==
'lm-format-enforcer'
:
return
await
get_lm_format_enforcer_guided_decoding_logits_processor
(
request
,
tokenizer
)
raise
ValueError
(
f
"Unknown guided decoding backend '
{
guided_decoding_backend
}
'. "
"Must be one of 'outlines, 'lm-format-enforcer'"
)
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
0 → 100644
View file @
05434764
from
functools
import
lru_cache
from
json
import
loads
as
json_loads
from
typing
import
Optional
,
Union
from
lmformatenforcer
import
(
CharacterLevelParser
,
JsonSchemaParser
,
RegexParser
,
StringParser
,
TokenEnforcerTokenizerData
,
UnionParser
)
from
lmformatenforcer.integrations.vllm
import
(
build_vllm_logits_processor
,
build_vllm_token_enforcer_tokenizer_data
)
from
pydantic
import
BaseModel
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
)
from
vllm.model_executor.guided_decoding.outlines_decoding
import
(
get_outlines_guided_decoding_logits_processor
)
from
vllm.sampling_params
import
LogitsProcessor
async
def
get_lm_format_enforcer_guided_decoding_logits_processor
(
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
],
tokenizer
)
->
Optional
[
LogitsProcessor
]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
tokenizer_data
=
_cached_build_vllm_token_enforcer_tokenizer_data
(
tokenizer
)
character_level_parser
:
CharacterLevelParser
if
request
.
guided_json
:
schema
=
_normalize_json_schema_object
(
request
.
guided_json
)
character_level_parser
=
JsonSchemaParser
(
schema
)
elif
request
.
guided_choice
:
character_level_parser
=
UnionParser
(
[
StringParser
(
choice
)
for
choice
in
request
.
guided_choice
])
elif
request
.
guided_regex
:
character_level_parser
=
RegexParser
(
request
.
guided_regex
)
elif
request
.
guided_grammar
:
# CFG grammar not supported by LMFE, revert to outlines
return
await
get_outlines_guided_decoding_logits_processor
(
request
,
tokenizer
)
elif
(
request
.
response_format
is
not
None
and
request
.
response_format
.
type
==
"json_object"
):
character_level_parser
=
JsonSchemaParser
(
None
)
# None means any json object
else
:
return
None
logits_processor
=
build_vllm_logits_processor
(
tokenizer_data
,
character_level_parser
)
return
logits_processor
def
_normalize_json_schema_object
(
schema
:
Union
[
str
,
dict
,
BaseModel
])
->
dict
:
if
isinstance
(
schema
,
str
):
return
json_loads
(
schema
)
if
isinstance
(
schema
,
dict
):
return
schema
if
isinstance
(
schema
,
BaseModel
):
return
schema
.
model_json_schema
()
@
lru_cache
def
_cached_build_vllm_token_enforcer_tokenizer_data
(
tokenizer
:
PreTrainedTokenizerBase
)
->
TokenEnforcerTokenizerData
:
return
build_vllm_token_enforcer_tokenizer_data
(
tokenizer
)
vllm/model_executor/guided_decoding.py
→
vllm/model_executor/guided_decoding
/outlines_decoding
.py
View file @
05434764
...
@@ -12,9 +12,8 @@ from transformers import PreTrainedTokenizerBase
...
@@ -12,9 +12,8 @@ from transformers import PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
)
CompletionRequest
)
from
vllm.model_executor.guided_logits_processors
import
(
CFGLogitsProcessor
,
from
vllm.model_executor.guided_decoding.outlines_logits_processors
import
(
JSONLogitsProcessor
,
CFGLogitsProcessor
,
JSONLogitsProcessor
,
RegexLogitsProcessor
)
RegexLogitsProcessor
)
class
GuidedDecodingMode
(
Enum
):
class
GuidedDecodingMode
(
Enum
):
...
@@ -54,7 +53,7 @@ pair : UNESCAPED_STRING ":" value
...
@@ -54,7 +53,7 @@ pair : UNESCAPED_STRING ":" value
global_thread_pool
=
None
# used for generating logits processor fsm
global_thread_pool
=
None
# used for generating logits processor fsm
async
def
get_guided_decoding_logits_processor
(
async
def
get_
outlines_
guided_decoding_logits_processor
(
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
],
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
],
tokenizer
)
->
Union
[
JSONLogitsProcessor
,
RegexLogitsProcessor
]:
tokenizer
)
->
Union
[
JSONLogitsProcessor
,
RegexLogitsProcessor
]:
"""
"""
...
...
vllm/model_executor/guided_logits_processors.py
→
vllm/model_executor/guided_
decoding/outlines_
logits_processors.py
View file @
05434764
...
@@ -13,9 +13,11 @@
...
@@ -13,9 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
copy
import
json
import
json
import
math
import
math
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
lru_cache
from
typing
import
Callable
,
DefaultDict
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Callable
,
DefaultDict
,
Dict
,
List
,
Optional
,
Union
import
torch
import
torch
...
@@ -27,50 +29,6 @@ from transformers import PreTrainedTokenizerBase
...
@@ -27,50 +29,6 @@ from transformers import PreTrainedTokenizerBase
class
BaseLogitsProcessor
:
class
BaseLogitsProcessor
:
def
adapt_tokenizer
(
self
,
tokenizer
:
PreTrainedTokenizerBase
):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addition
we need to handle the missing spaces to Llama's tokenizer to be
able to compile FSMs for this model.
"""
if
getattr
(
tokenizer
,
"_outlines_adapted"
,
False
):
return
tokenizer
tokenizer
.
vocabulary
=
tokenizer
.
get_vocab
()
tokenizer
.
special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
def
convert_token_to_string
(
token
:
str
)
->
str
:
from
transformers.file_utils
import
SPIECE_UNDERLINE
string
=
tokenizer
.
convert_tokens_to_string
([
token
])
# A hack to handle missing spaces to HF's Llama tokenizers
if
token
.
startswith
(
SPIECE_UNDERLINE
)
or
token
==
"<0x20>"
:
return
" "
+
string
return
string
def
change_decoder
(
decoder
:
Callable
[[
List
[
int
]],
str
]
)
->
Callable
[[
List
[
int
]],
List
[
str
]]:
"""Sync vLLM's decoder with the outlines by returning list."""
def
new_decoder
(
inp_tokens
:
List
[
int
])
->
List
[
str
]:
return
[
decoder
(
inp_tokens
)]
return
new_decoder
tokenizer
.
convert_token_to_string
=
convert_token_to_string
tokenizer
.
decode
=
change_decoder
(
tokenizer
.
decode
)
setattr
(
tokenizer
,
"_outlines_adapted"
,
True
)
# noqa: B010
return
tokenizer
def
init_state
(
self
):
def
init_state
(
self
):
"""Initialize the FSM states."""
"""Initialize the FSM states."""
self
.
fsm_state
:
DefaultDict
[
int
,
int
]
=
defaultdict
(
int
)
self
.
fsm_state
:
DefaultDict
[
int
,
int
]
=
defaultdict
(
int
)
...
@@ -78,7 +36,6 @@ class BaseLogitsProcessor:
...
@@ -78,7 +36,6 @@ class BaseLogitsProcessor:
def
__call__
(
self
,
input_ids
:
List
[
int
],
def
__call__
(
self
,
input_ids
:
List
[
int
],
scores
:
torch
.
Tensor
)
->
torch
.
Tensor
:
scores
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Use the FSM to bias the logits before sampling the next token."""
"""Use the FSM to bias the logits before sampling the next token."""
seq_id
=
hash
(
tuple
(
input_ids
))
seq_id
=
hash
(
tuple
(
input_ids
))
if
len
(
input_ids
)
==
0
:
if
len
(
input_ids
)
==
0
:
...
@@ -96,7 +53,6 @@ class BaseLogitsProcessor:
...
@@ -96,7 +53,6 @@ class BaseLogitsProcessor:
device
=
scores
.
device
)
device
=
scores
.
device
)
mask
[
allowed_tokens
]
=
0
mask
[
allowed_tokens
]
=
0
scores
.
add_
(
mask
)
scores
.
add_
(
mask
)
return
scores
return
scores
...
@@ -113,7 +69,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
...
@@ -113,7 +69,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
The model's tokenizer
The model's tokenizer
"""
"""
tokenizer
=
self
.
adapt_tokenizer
(
tokenizer
)
tokenizer
=
_
adapt_tokenizer
(
tokenizer
)
fsm
=
RegexFSM
(
regex_string
,
tokenizer
)
fsm
=
RegexFSM
(
regex_string
,
tokenizer
)
self
.
fsm
=
fsm
self
.
fsm
=
fsm
...
@@ -167,6 +123,54 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
...
@@ -167,6 +123,54 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
The model's tokenizer
The model's tokenizer
"""
"""
tokenizer
=
self
.
adapt_tokenizer
(
tokenizer
)
tokenizer
=
_
adapt_tokenizer
(
tokenizer
)
fsm
=
CFGFSM
(
cfg
,
tokenizer
)
fsm
=
CFGFSM
(
cfg
,
tokenizer
)
self
.
fsm
=
fsm
self
.
fsm
=
fsm
@
lru_cache
def
_adapt_tokenizer
(
tokenizer
:
PreTrainedTokenizerBase
):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addition
we need to handle the missing spaces to Llama's tokenizer to be
able to compile FSMs for this model.
"""
if
getattr
(
tokenizer
,
"_outlines_adapted"
,
False
):
return
tokenizer
tokenizer
=
copy
.
deepcopy
(
tokenizer
)
tokenizer
.
vocabulary
=
tokenizer
.
get_vocab
()
tokenizer
.
special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
def
convert_token_to_string
(
token
:
str
)
->
str
:
from
transformers.file_utils
import
SPIECE_UNDERLINE
string
=
tokenizer
.
convert_tokens_to_string
([
token
])
# A hack to handle missing spaces to HF's Llama tokenizers
if
token
.
startswith
(
SPIECE_UNDERLINE
)
or
token
==
"<0x20>"
:
return
" "
+
string
return
string
def
change_decoder
(
decoder
:
Callable
[[
List
[
int
]],
str
])
->
Callable
[[
List
[
int
]],
List
[
str
]]:
"""Sync vLLM's decoder with the outlines by returning list."""
def
new_decoder
(
inp_tokens
:
List
[
int
])
->
List
[
str
]:
return
[
decoder
(
inp_tokens
)]
return
new_decoder
tokenizer
.
convert_token_to_string
=
convert_token_to_string
tokenizer
.
decode
=
change_decoder
(
tokenizer
.
decode
)
setattr
(
tokenizer
,
"_outlines_adapted"
,
True
)
# noqa: B010
return
tokenizer
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment