Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
05434764
Unverified
Commit
05434764
authored
Apr 16, 2024
by
Noam Gat
Committed by
GitHub
Apr 16, 2024
Browse files
LM Format Enforcer Guided Decoding Support (#3868)
Co-authored-by:
Simon Mo
<
simon.mo@hey.com
>
parent
4e7ee664
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
304 additions
and
87 deletions
+304
-87
requirements-common.txt
requirements-common.txt
+1
-0
tests/entrypoints/test_guided_processors.py
tests/entrypoints/test_guided_processors.py
+39
-3
tests/entrypoints/test_openai_server.py
tests/entrypoints/test_openai_server.py
+50
-19
vllm/config.py
vllm/config.py
+21
-5
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+15
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+7
-3
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+12
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+5
-1
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+5
-1
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+25
-0
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
...l_executor/guided_decoding/lm_format_enforcer_decoding.py
+69
-0
vllm/model_executor/guided_decoding/outlines_decoding.py
vllm/model_executor/guided_decoding/outlines_decoding.py
+3
-4
vllm/model_executor/guided_decoding/outlines_logits_processors.py
...el_executor/guided_decoding/outlines_logits_processors.py
+52
-48
No files found.
requirements-common.txt
View file @
05434764
...
...
@@ -11,6 +11,7 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.9.3
outlines == 0.0.34 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
tests/entrypoints/test_guided_processors.py
View file @
05434764
# This unit test should be moved to a new
# tests/test_guided_decoding directory.
import
pytest
import
torch
from
transformers
import
AutoTokenizer
from
vllm.model_executor.guided_logits_processors
import
(
JSONLogitsProcessor
,
RegexLogitsProcessor
)
from
vllm.entrypoints.openai.protocol
import
CompletionRequest
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.model_executor.guided_decoding.outlines_logits_processors
import
(
JSONLogitsProcessor
,
RegexLogitsProcessor
)
TEST_SCHEMA
=
{
"type"
:
"object"
,
...
...
@@ -73,3 +76,36 @@ def test_guided_logits_processors():
json_LP
(
token_ids
,
tensor
)
assert
tensor
.
shape
==
original_tensor
.
shape
assert
not
torch
.
allclose
(
tensor
,
original_tensor
)
@pytest.mark.asyncio
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
async def test_guided_logits_processor_black_box(backend: str):
    """Check that each backend yields a processor that alters the logits.

    The regex guide is exercised first, then the JSON-schema guide. For
    each one the processor must exist, keep the tensor shape, and change
    at least some of the scores.
    """
    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')

    # (prompt text, guided-decoding keyword arguments) in execution order.
    scenarios = (
        (f"Give an example IPv4 address with this regex: {TEST_REGEX}",
         dict(guided_regex=TEST_REGEX)),
        (f"Give an employee profile that fits this schema: {TEST_SCHEMA}",
         dict(guided_json=TEST_SCHEMA)),
    )

    for prompt_text, guide_kwargs in scenarios:
        token_ids = tokenizer.encode(prompt_text)
        request = CompletionRequest(model='test',
                                    prompt=token_ids,
                                    **guide_kwargs)
        processor = await get_guided_decoding_logits_processor(
            backend, request, tokenizer)
        assert processor is not None

        scores = torch.rand(32000)
        baseline = torch.clone(scores)
        scores = processor(token_ids, scores)
        assert scores.shape == baseline.shape
        assert not torch.allclose(scores, baseline)
tests/entrypoints/test_openai_server.py
View file @
05434764
...
...
@@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
assert
first_response
!=
completion
.
choices
[
0
].
text
async
def
test_guided_json_completion
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_json_completion
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
completion
=
await
client
.
completions
.
create
(
model
=
MODEL_NAME
,
prompt
=
f
"Give an example JSON for an employee profile "
...
...
@@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
n
=
3
,
temperature
=
1.0
,
max_tokens
=
500
,
extra_body
=
dict
(
guided_json
=
TEST_SCHEMA
))
extra_body
=
dict
(
guided_json
=
TEST_SCHEMA
,
guided_decoding_backend
=
guided_decoding_backend
))
assert
completion
.
id
is
not
None
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
3
...
...
@@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
TEST_SCHEMA
)
async
def
test_guided_json_chat
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_json_chat
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
messages
=
[{
"role"
:
"system"
,
"content"
:
"you are a helpful assistant"
...
...
@@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
500
,
extra_body
=
dict
(
guided_json
=
TEST_SCHEMA
))
max_tokens
=
1000
,
extra_body
=
dict
(
guided_json
=
TEST_SCHEMA
,
guided_decoding_backend
=
guided_decoding_backend
))
message
=
chat_completion
.
choices
[
0
].
message
assert
message
.
content
is
not
None
json1
=
json
.
loads
(
message
.
content
)
...
...
@@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
500
,
extra_body
=
dict
(
guided_json
=
TEST_SCHEMA
))
max_tokens
=
1000
,
extra_body
=
dict
(
guided_json
=
TEST_SCHEMA
,
guided_decoding_backend
=
guided_decoding_backend
))
message
=
chat_completion
.
choices
[
0
].
message
assert
message
.
content
is
not
None
json2
=
json
.
loads
(
message
.
content
)
...
...
@@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
assert
json1
[
"age"
]
!=
json2
[
"age"
]
async
def
test_guided_regex_completion
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_regex_completion
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
completion
=
await
client
.
completions
.
create
(
model
=
MODEL_NAME
,
prompt
=
f
"Give an example IPv4 address with this regex:
{
TEST_REGEX
}
"
,
n
=
3
,
temperature
=
1.0
,
max_tokens
=
20
,
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
))
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
,
guided_decoding_backend
=
guided_decoding_backend
))
assert
completion
.
id
is
not
None
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
3
...
...
@@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
assert
re
.
fullmatch
(
TEST_REGEX
,
completion
.
choices
[
i
].
text
)
is
not
None
async
def
test_guided_regex_chat
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_regex_chat
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
messages
=
[{
"role"
:
"system"
,
"content"
:
"you are a helpful assistant"
...
...
@@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
20
,
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
))
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
,
guided_decoding_backend
=
guided_decoding_backend
))
ip1
=
chat_completion
.
choices
[
0
].
message
.
content
assert
ip1
is
not
None
assert
re
.
fullmatch
(
TEST_REGEX
,
ip1
)
is
not
None
...
...
@@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
20
,
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
))
extra_body
=
dict
(
guided_regex
=
TEST_REGEX
,
guided_decoding_backend
=
guided_decoding_backend
))
ip2
=
chat_completion
.
choices
[
0
].
message
.
content
assert
ip2
is
not
None
assert
re
.
fullmatch
(
TEST_REGEX
,
ip2
)
is
not
None
assert
ip1
!=
ip2
async
def
test_guided_choice_completion
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_choice_completion
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
completion
=
await
client
.
completions
.
create
(
model
=
MODEL_NAME
,
prompt
=
"The best language for type-safe systems programming is "
,
n
=
2
,
temperature
=
1.0
,
max_tokens
=
10
,
extra_body
=
dict
(
guided_choice
=
TEST_CHOICE
))
extra_body
=
dict
(
guided_choice
=
TEST_CHOICE
,
guided_decoding_backend
=
guided_decoding_backend
))
assert
completion
.
id
is
not
None
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
2
...
...
@@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
assert
completion
.
choices
[
i
].
text
in
TEST_CHOICE
async
def
test_guided_choice_chat
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_choice_chat
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
messages
=
[{
"role"
:
"system"
,
"content"
:
"you are a helpful assistant"
...
...
@@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
10
,
extra_body
=
dict
(
guided_choice
=
TEST_CHOICE
))
extra_body
=
dict
(
guided_choice
=
TEST_CHOICE
,
guided_decoding_backend
=
guided_decoding_backend
))
choice1
=
chat_completion
.
choices
[
0
].
message
.
content
assert
choice1
in
TEST_CHOICE
...
...
@@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
10
,
extra_body
=
dict
(
guided_choice
=
TEST_CHOICE
))
extra_body
=
dict
(
guided_choice
=
TEST_CHOICE
,
guided_decoding_backend
=
guided_decoding_backend
))
choice2
=
chat_completion
.
choices
[
0
].
message
.
content
assert
choice2
in
TEST_CHOICE
assert
choice1
!=
choice2
async
def
test_guided_decoding_type_error
(
server
,
client
:
openai
.
AsyncOpenAI
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
async
def
test_guided_decoding_type_error
(
server
,
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
):
with
pytest
.
raises
(
openai
.
BadRequestError
):
_
=
await
client
.
completions
.
create
(
model
=
MODEL_NAME
,
prompt
=
"Give an example JSON that fits this schema: 42"
,
extra_body
=
dict
(
guided_json
=
42
))
extra_body
=
dict
(
guided_json
=
42
,
guided_decoding_backend
=
guided_decoding_backend
))
messages
=
[{
"role"
:
"system"
,
...
...
vllm/config.py
View file @
05434764
...
...
@@ -66,8 +66,8 @@ class ModelConfig:
weights. If None, we assume the model weights are not quantized.
quantization_param_path: Path to JSON file containing scaling factors.
Used to load KV cache scaling factors into the model when KV cache
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
be used to load activation and weight scaling factors when the
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
be used to load activation and weight scaling factors when the
model dtype is FP8_E4M3 on ROCm.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
...
...
@@ -422,7 +422,7 @@ class CacheConfig:
@
dataclass
class
TokenizerPoolConfig
:
"""Configuration for the tokenizer pool.
Args:
pool_size: Number of tokenizer workers in the pool.
pool_type: Type of the pool.
...
...
@@ -446,9 +446,9 @@ class TokenizerPoolConfig:
tokenizer_pool_extra_config
:
Optional
[
Union
[
str
,
dict
]]
)
->
Optional
[
"TokenizerPoolConfig"
]:
"""Create a TokenizerPoolConfig from the given parameters.
If tokenizer_pool_size is 0, return None.
Args:
tokenizer_pool_size: Number of tokenizer workers in the pool.
tokenizer_pool_type: Type of the pool.
...
...
@@ -1079,6 +1079,21 @@ def _get_and_verify_max_len(
return
int
(
max_model_len
)
@dataclass
class DecodingConfig:
    """Dataclass which contains the decoding strategy of the engine.

    Raises:
        ValueError: From ``__post_init__`` when ``guided_decoding_backend``
            is not one of the supported backend names.
    """

    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
    guided_decoding_backend: str = 'outlines'

    def __post_init__(self):
        """Reject backend names that are not in the supported list."""
        valid_guided_backends = ['outlines', 'lm-format-enforcer']
        backend = self.guided_decoding_backend
        if backend not in valid_guided_backends:
            # Keep the quotes balanced and a space after the comma so the
            # two f-string fragments join into a readable sentence.
            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                             f"must be one of {valid_guided_backends}")
@
dataclass
(
frozen
=
True
)
class
EngineConfig
:
"""Dataclass which contains all engine-related configuration. This
...
...
@@ -1093,6 +1108,7 @@ class EngineConfig:
lora_config
:
Optional
[
LoRAConfig
]
vision_language_config
:
Optional
[
VisionLanguageConfig
]
speculative_config
:
Optional
[
SpeculativeConfig
]
decoding_config
:
Optional
[
DecodingConfig
]
tensorizer_config
:
Optional
[
TensorizerConfig
]
def
__post_init__
(
self
):
...
...
vllm/engine/arg_utils.py
View file @
05434764
...
...
@@ -5,9 +5,9 @@ import os
from
dataclasses
import
dataclass
from
typing
import
BinaryIO
,
Optional
,
Union
from
vllm.config
import
(
CacheConfig
,
De
viceConfig
,
Engine
Config
,
LoRA
Config
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TensorizerConfig
,
from
vllm.config
import
(
CacheConfig
,
De
coding
Config
,
Device
Config
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TensorizerConfig
,
TokenizerPoolConfig
,
VisionLanguageConfig
)
from
vllm.model_executor.tensorizer_loader
import
TensorizerArgs
from
vllm.utils
import
str_to_int_tuple
...
...
@@ -80,6 +80,7 @@ class EngineArgs:
scheduler_delay_factor
:
float
=
0.0
enable_chunked_prefill
:
bool
=
False
guided_decoding_backend
:
str
=
'outlines'
# Speculative decoding configuration.
speculative_model
:
Optional
[
str
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
...
...
@@ -200,6 +201,13 @@ class EngineArgs:
default
=
EngineArgs
.
max_model_len
,
help
=
'model context length. If unspecified, '
'will be automatically derived from the model.'
)
parser
.
add_argument
(
'--guided-decoding-backend'
,
type
=
str
,
default
=
'outlines'
,
choices
=
[
'outlines'
,
'lm-format-enforcer'
],
help
=
'Which engine will be used for guided decoding'
' (JSON schema / regex etc)'
)
# Parallel arguments
parser
.
add_argument
(
'--worker-use-ray'
,
action
=
'store_true'
,
...
...
@@ -511,6 +519,9 @@ class EngineArgs:
else
:
vision_language_config
=
None
decoding_config
=
DecodingConfig
(
guided_decoding_backend
=
self
.
guided_decoding_backend
)
return
EngineConfig
(
model_config
=
model_config
,
cache_config
=
cache_config
,
parallel_config
=
parallel_config
,
...
...
@@ -519,6 +530,7 @@ class EngineArgs:
lora_config
=
lora_config
,
vision_language_config
=
vision_language_config
,
speculative_config
=
speculative_config
,
decoding_config
=
decoding_config
,
tensorizer_config
=
tensorizer_config
)
...
...
vllm/engine/llm_engine.py
View file @
05434764
...
...
@@ -4,9 +4,10 @@ from typing import Iterable, List, Optional, Tuple, Type, Union
from
transformers
import
PreTrainedTokenizer
import
vllm
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TensorizerConfig
,
VisionLanguageConfig
)
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TensorizerConfig
,
VisionLanguageConfig
)
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics
import
StatLogger
,
Stats
...
...
@@ -74,6 +75,7 @@ class LLMEngine:
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
decoding_config
:
Optional
[
DecodingConfig
],
tensorizer_config
:
Optional
[
TensorizerConfig
],
executor_class
:
Type
[
ExecutorBase
],
log_stats
:
bool
,
...
...
@@ -100,6 +102,7 @@ class LLMEngine:
f
"kv_cache_dtype=
{
cache_config
.
cache_dtype
}
, "
f
"quantization_param_path=
{
model_config
.
quantization_param_path
}
, "
f
"device_config=
{
device_config
.
device
}
, "
f
"decoding_config=
{
decoding_config
!
r
}
, "
f
"seed=
{
model_config
.
seed
}
)"
)
# TODO(woosuk): Print more configs in debug mode.
...
...
@@ -111,6 +114,7 @@ class LLMEngine:
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
speculative_config
=
speculative_config
self
.
decoding_config
=
decoding_config
or
DecodingConfig
()
self
.
tensorizer_config
=
tensorizer_config
self
.
log_stats
=
log_stats
...
...
vllm/entrypoints/openai/protocol.py
View file @
05434764
...
...
@@ -133,6 +133,12 @@ class ChatCompletionRequest(BaseModel):
description
=
(
"If specified, the output will follow the context free grammar."
),
)
guided_decoding_backend
:
Optional
[
str
]
=
Field
(
default
=
None
,
description
=
(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"
))
# doc: end-chat-completion-extra-params
...
...
@@ -265,6 +271,12 @@ class CompletionRequest(BaseModel):
description
=
(
"If specified, the output will follow the context free grammar."
),
)
guided_decoding_backend
:
Optional
[
str
]
=
Field
(
default
=
None
,
description
=
(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"
))
# doc: end-completion-extra-params
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
05434764
...
...
@@ -68,9 +68,13 @@ class OpenAIServingChat(OpenAIServing):
request
,
prompt
=
prompt
)
sampling_params
=
request
.
to_sampling_params
()
lora_request
=
self
.
_maybe_get_lora
(
request
)
decoding_config
=
self
.
engine
.
engine
.
decoding_config
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logits_processor
=
(
await
get_guided_decoding_logits_processor
(
request
,
await
self
.
engine
.
get_tokenizer
()))
guided_decoding_backend
,
request
,
await
self
.
engine
.
get_tokenizer
()))
if
guided_decode_logits_processor
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
05434764
...
...
@@ -88,9 +88,13 @@ class OpenAIServingCompletion(OpenAIServing):
try
:
sampling_params
=
request
.
to_sampling_params
()
lora_request
=
self
.
_maybe_get_lora
(
request
)
decoding_config
=
self
.
engine
.
engine
.
decoding_config
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
guided_decode_logit_processor
=
(
await
get_guided_decoding_logits_processor
(
request
,
await
self
.
engine
.
get_tokenizer
()))
guided_decoding_backend
,
request
,
await
self
.
engine
.
get_tokenizer
()))
if
guided_decode_logit_processor
is
not
None
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
...
...
vllm/model_executor/guided_decoding/__init__.py
0 → 100644
View file @
05434764
from
typing
import
Optional
,
Union
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
)
from
vllm.model_executor.guided_decoding.lm_format_enforcer_decoding
import
(
get_lm_format_enforcer_guided_decoding_logits_processor
)
from
vllm.model_executor.guided_decoding.outlines_decoding
import
(
get_outlines_guided_decoding_logits_processor
)
from
vllm.sampling_params
import
LogitsProcessor
async def get_guided_decoding_logits_processor(
        guided_decoding_backend: str, request: Union[CompletionRequest,
                                                     ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessor]:
    """Build the guided-decoding logits processor for the chosen backend.

    Args:
        guided_decoding_backend: Name of the backend to use, either
            'outlines' or 'lm-format-enforcer'.
        request: The OpenAI-compatible request, forwarded unchanged to the
            backend-specific factory.
        tokenizer: The tokenizer, forwarded unchanged to the
            backend-specific factory.

    Returns:
        Whatever the selected backend factory returns.

    Raises:
        ValueError: If ``guided_decoding_backend`` is not a known backend.
    """
    if guided_decoding_backend == 'outlines':
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    if guided_decoding_backend == 'lm-format-enforcer':
        return await get_lm_format_enforcer_guided_decoding_logits_processor(
            request, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
0 → 100644
View file @
05434764
from
functools
import
lru_cache
from
json
import
loads
as
json_loads
from
typing
import
Optional
,
Union
from
lmformatenforcer
import
(
CharacterLevelParser
,
JsonSchemaParser
,
RegexParser
,
StringParser
,
TokenEnforcerTokenizerData
,
UnionParser
)
from
lmformatenforcer.integrations.vllm
import
(
build_vllm_logits_processor
,
build_vllm_token_enforcer_tokenizer_data
)
from
pydantic
import
BaseModel
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
)
from
vllm.model_executor.guided_decoding.outlines_decoding
import
(
get_outlines_guided_decoding_logits_processor
)
from
vllm.sampling_params
import
LogitsProcessor
async def get_lm_format_enforcer_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessor]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.

    The guide is picked in this order: ``guided_json``, ``guided_choice``,
    ``guided_regex``, ``guided_grammar``, then a ``response_format`` of type
    "json_object". Grammar guides are delegated to the outlines backend.
    If none applies, None is returned.

    Only the tokenizer data is cached (per tokenizer); the character-level
    parser and the logits processor are built anew on every call.
    """
    character_level_parser: CharacterLevelParser
    if request.guided_json:
        schema = _normalize_json_schema_object(request.guided_json)
        character_level_parser = JsonSchemaParser(schema)
    elif request.guided_choice:
        character_level_parser = UnionParser(
            [StringParser(choice) for choice in request.guided_choice])
    elif request.guided_regex:
        character_level_parser = RegexParser(request.guided_regex)
    elif request.guided_grammar:
        # CFG grammar not supported by LMFE, revert to outlines
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    elif (request.response_format is not None
          and request.response_format.type == "json_object"):
        character_level_parser = JsonSchemaParser(
            None)  # None means any json object
    else:
        return None

    # Build (or fetch from cache) the tokenizer data only once we know an
    # LMFE processor is needed, so unguided and grammar requests skip it.
    tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
        tokenizer)
    return build_vllm_logits_processor(tokenizer_data,
                                       character_level_parser)
def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
    """Convert a JSON schema given in any supported form.

    Args:
        schema: A JSON string, an already-parsed dict, or a pydantic
            ``BaseModel`` instance.

    Returns:
        The parsed JSON for a string, the dict itself for a dict, or the
        result of ``model_json_schema()`` for a ``BaseModel`` instance.

    Raises:
        TypeError: If ``schema`` is none of the supported types. Without
            this the function would fall through and return None, which
            contradicts the declared return type.
    """
    if isinstance(schema, str):
        return json_loads(schema)
    if isinstance(schema, dict):
        return schema
    if isinstance(schema, BaseModel):
        return schema.model_json_schema()
    raise TypeError(
        "JSON schema must be a str, dict or pydantic BaseModel, "
        f"got {type(schema).__name__}")
@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
        tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
    """Memoised wrapper around build_vllm_token_enforcer_tokenizer_data.

    ``lru_cache`` keys on the tokenizer argument, so repeated calls with
    the same tokenizer reuse the previously built data.
    """
    tokenizer_data = build_vllm_token_enforcer_tokenizer_data(tokenizer)
    return tokenizer_data
vllm/model_executor/guided_decoding.py
→
vllm/model_executor/guided_decoding
/outlines_decoding
.py
View file @
05434764
...
...
@@ -12,9 +12,8 @@ from transformers import PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
)
from
vllm.model_executor.guided_logits_processors
import
(
CFGLogitsProcessor
,
JSONLogitsProcessor
,
RegexLogitsProcessor
)
from
vllm.model_executor.guided_decoding.outlines_logits_processors
import
(
CFGLogitsProcessor
,
JSONLogitsProcessor
,
RegexLogitsProcessor
)
class
GuidedDecodingMode
(
Enum
):
...
...
@@ -54,7 +53,7 @@ pair : UNESCAPED_STRING ":" value
global_thread_pool
=
None
# used for generating logits processor fsm
async
def
get_guided_decoding_logits_processor
(
async
def
get_
outlines_
guided_decoding_logits_processor
(
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
],
tokenizer
)
->
Union
[
JSONLogitsProcessor
,
RegexLogitsProcessor
]:
"""
...
...
vllm/model_executor/guided_logits_processors.py
→
vllm/model_executor/guided_
decoding/outlines_
logits_processors.py
View file @
05434764
...
...
@@ -13,9 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
import
json
import
math
from
collections
import
defaultdict
from
functools
import
lru_cache
from
typing
import
Callable
,
DefaultDict
,
Dict
,
List
,
Optional
,
Union
import
torch
...
...
@@ -27,50 +29,6 @@ from transformers import PreTrainedTokenizerBase
class
BaseLogitsProcessor
:
def
adapt_tokenizer
(
self
,
tokenizer
:
PreTrainedTokenizerBase
):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addition
we need to handle the missing spaces to Llama's tokenizer to be
able to compile FSMs for this model.
"""
if
getattr
(
tokenizer
,
"_outlines_adapted"
,
False
):
return
tokenizer
tokenizer
.
vocabulary
=
tokenizer
.
get_vocab
()
tokenizer
.
special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
def
convert_token_to_string
(
token
:
str
)
->
str
:
from
transformers.file_utils
import
SPIECE_UNDERLINE
string
=
tokenizer
.
convert_tokens_to_string
([
token
])
# A hack to handle missing spaces to HF's Llama tokenizers
if
token
.
startswith
(
SPIECE_UNDERLINE
)
or
token
==
"<0x20>"
:
return
" "
+
string
return
string
def
change_decoder
(
decoder
:
Callable
[[
List
[
int
]],
str
]
)
->
Callable
[[
List
[
int
]],
List
[
str
]]:
"""Sync vLLM's decoder with the outlines by returning list."""
def
new_decoder
(
inp_tokens
:
List
[
int
])
->
List
[
str
]:
return
[
decoder
(
inp_tokens
)]
return
new_decoder
tokenizer
.
convert_token_to_string
=
convert_token_to_string
tokenizer
.
decode
=
change_decoder
(
tokenizer
.
decode
)
setattr
(
tokenizer
,
"_outlines_adapted"
,
True
)
# noqa: B010
return
tokenizer
def
init_state
(
self
):
"""Initialize the FSM states."""
self
.
fsm_state
:
DefaultDict
[
int
,
int
]
=
defaultdict
(
int
)
...
...
@@ -78,7 +36,6 @@ class BaseLogitsProcessor:
def
__call__
(
self
,
input_ids
:
List
[
int
],
scores
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Use the FSM to bias the logits before sampling the next token."""
seq_id
=
hash
(
tuple
(
input_ids
))
if
len
(
input_ids
)
==
0
:
...
...
@@ -96,7 +53,6 @@ class BaseLogitsProcessor:
device
=
scores
.
device
)
mask
[
allowed_tokens
]
=
0
scores
.
add_
(
mask
)
return
scores
...
...
@@ -113,7 +69,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
The model's tokenizer
"""
tokenizer
=
self
.
adapt_tokenizer
(
tokenizer
)
tokenizer
=
_
adapt_tokenizer
(
tokenizer
)
fsm
=
RegexFSM
(
regex_string
,
tokenizer
)
self
.
fsm
=
fsm
...
...
@@ -167,6 +123,54 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
The model's tokenizer
"""
tokenizer
=
self
.
adapt_tokenizer
(
tokenizer
)
tokenizer
=
_
adapt_tokenizer
(
tokenizer
)
fsm
=
CFGFSM
(
cfg
,
tokenizer
)
self
.
fsm
=
fsm
@lru_cache
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
    """Return a copy of vLLM's tokenizer adapted for compiling the FSM.

    Outlines expects a slightly different tokenizer API than
    `transformers` provides: its decoder returns a list, whereas vLLM's
    decode returns a str. The adapted copy wraps ``decode`` accordingly
    and adds ``vocabulary``, ``special_tokens`` and
    ``convert_token_to_string``, the last of which restores the leading
    space that HF's Llama tokenizers drop.

    A tokenizer already flagged ``_outlines_adapted`` is returned as is.
    Otherwise a deep copy is adapted, so the tokenizer passed in is left
    unmodified. Results are memoised per tokenizer by ``lru_cache``.
    """
    if getattr(tokenizer, "_outlines_adapted", False):
        return tokenizer

    adapted = copy.deepcopy(tokenizer)
    adapted.vocabulary = adapted.get_vocab()
    adapted.special_tokens = set(adapted.all_special_tokens)

    def convert_token_to_string(token: str) -> str:
        from transformers.file_utils import SPIECE_UNDERLINE

        text = adapted.convert_tokens_to_string([token])

        # A hack to handle missing spaces to HF's Llama tokenizers
        needs_space = token.startswith(SPIECE_UNDERLINE) or token == "<0x20>"
        return " " + text if needs_space else text

    # Capture the copy's str-returning decode before replacing it below.
    str_decode = adapted.decode

    def decode_to_list(inp_tokens: List[int]) -> List[str]:
        """Sync vLLM's decoder with the outlines by returning list."""
        return [str_decode(inp_tokens)]

    adapted.convert_token_to_string = convert_token_to_string
    adapted.decode = decode_to_list
    setattr(adapted, "_outlines_adapted", True)  # noqa: B010

    return adapted
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment