Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
654bc5ca
Unverified
Commit
654bc5ca
authored
Aug 03, 2024
by
Yihuan Bu
Committed by
GitHub
Aug 04, 2024
Browse files
Support for guided decoding for offline LLM (#6878)
Co-authored-by:
Cyrus Leung
<
cyrus.tl.leung@gmail.com
>
parent
825b0448
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
352 additions
and
12 deletions
+352
-12
docs/source/conf.py
docs/source/conf.py
+1
-0
tests/entrypoints/conftest.py
tests/entrypoints/conftest.py
+21
-1
tests/entrypoints/llm/test_guided_generate.py
tests/entrypoints/llm/test_guided_generate.py
+142
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+43
-1
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+20
-6
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+24
-2
vllm/model_executor/guided_decoding/guided_fields.py
vllm/model_executor/guided_decoding/guided_fields.py
+38
-0
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
...l_executor/guided_decoding/lm_format_enforcer_decoding.py
+39
-0
vllm/model_executor/guided_decoding/outlines_decoding.py
vllm/model_executor/guided_decoding/outlines_decoding.py
+24
-2
No files found.
docs/source/conf.py
View file @
654bc5ca
...
@@ -111,6 +111,7 @@ autodoc_mock_imports = [
...
@@ -111,6 +111,7 @@ autodoc_mock_imports = [
"tqdm"
,
"tqdm"
,
"tensorizer"
,
"tensorizer"
,
"pynvml"
,
"pynvml"
,
"outlines"
,
]
]
for
mock_target
in
autodoc_mock_imports
:
for
mock_target
in
autodoc_mock_imports
:
...
...
tests/entrypoints/
openai/
conftest.py
→
tests/entrypoints/conftest.py
View file @
654bc5ca
import
pytest
import
pytest
@
pytest
.
fixture
def
sample_prompts
():
return
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
@
pytest
.
fixture
def
sample_token_ids
():
return
[
[
0
],
[
0
,
1
],
[
0
,
2
,
1
],
[
0
,
3
,
1
,
2
],
]
@
pytest
.
fixture
@
pytest
.
fixture
def
sample_regex
():
def
sample_regex
():
return
(
r
"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
return
(
r
"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
...
@@ -66,4 +86,4 @@ column: "col_1" | "col_2"
...
@@ -66,4 +86,4 @@ column: "col_1" | "col_2"
table: "table_1" | "table_2"
table: "table_1" | "table_2"
condition: column "=" number
condition: column "=" number
number: "1" | "2"
number: "1" | "2"
"""
)
"""
)
\ No newline at end of file
tests/entrypoints/llm/test_guided_generate.py
0 → 100644
View file @
654bc5ca
import
json
import
re
import
weakref
import
jsonschema
import
pytest
from
vllm.entrypoints.llm
import
LLM
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
...conftest
import
cleanup
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
@
pytest
.
fixture
(
scope
=
"module"
)
def
llm
():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm
=
LLM
(
model
=
MODEL_NAME
,
max_model_len
=
1024
)
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
del
llm
cleanup
()
@
pytest
.
mark
.
skip_global_cleanup
def
test_guided_regex
(
sample_regex
,
llm
):
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
)
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example IPv4 address with this regex:
{
sample_regex
}
"
]
*
2
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
,
guided_options_request
=
dict
(
guided_regex
=
sample_regex
))
assert
outputs
is
not
None
for
output
in
outputs
:
assert
output
is
not
None
assert
isinstance
(
output
,
RequestOutput
)
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
generated_text
)
assert
generated_text
is
not
None
assert
re
.
fullmatch
(
sample_regex
,
generated_text
)
is
not
None
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_guided_json_completion
(
sample_json_schema
,
llm
):
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
1000
,
)
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for an employee profile "
f
"that fits this schema:
{
sample_json_schema
}
"
]
*
2
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
,
guided_options_request
=
dict
(
guided_json
=
sample_json_schema
))
assert
outputs
is
not
None
for
output
in
outputs
:
assert
output
is
not
None
assert
isinstance
(
output
,
RequestOutput
)
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
assert
generated_text
is
not
None
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
output_json
=
json
.
loads
(
generated_text
)
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
sample_json_schema
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_guided_choice_completion
(
sample_guided_choice
,
llm
):
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
)
outputs
=
llm
.
generate
(
prompts
=
"The best language for type-safe systems programming is "
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
,
guided_options_request
=
dict
(
guided_choice
=
sample_guided_choice
))
assert
outputs
is
not
None
for
output
in
outputs
:
assert
output
is
not
None
assert
isinstance
(
output
,
RequestOutput
)
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
generated_text
)
assert
generated_text
is
not
None
assert
generated_text
in
sample_guided_choice
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_guided_grammar
(
sample_sql_statements
,
llm
):
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
1000
,
)
outputs
=
llm
.
generate
(
prompts
=
(
"Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"
),
sampling_params
=
sampling_params
,
use_tqdm
=
True
,
guided_options_request
=
dict
(
guided_grammar
=
sample_sql_statements
))
assert
outputs
is
not
None
for
output
in
outputs
:
assert
output
is
not
None
assert
isinstance
(
output
,
RequestOutput
)
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
assert
generated_text
is
not
None
# use Lark to parse the output, and make sure it's a valid parse tree
from
lark
import
Lark
parser
=
Lark
(
sample_sql_statements
)
parser
.
parse
(
generated_text
)
# remove spaces for comparison b/c we removed them in the grammar
ground_truth
=
"SELECT col_1 from table_1 where col_1 = 1"
.
replace
(
" "
,
""
)
assert
generated_text
.
strip
()
==
ground_truth
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
vllm/entrypoints/llm.py
View file @
654bc5ca
...
@@ -10,6 +10,9 @@ from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
...
@@ -10,6 +10,9 @@ from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
parse_and_batch_prompt
)
parse_and_batch_prompt
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
GuidedDecodingRequest
,
get_local_guided_decoding_logits_processor
)
from
vllm.model_executor.guided_decoding.guided_fields
import
LLMGuidedOptions
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
...
@@ -262,6 +265,8 @@ class LLM:
...
@@ -262,6 +265,8 @@ class LLM:
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
guided_options_request
:
Optional
[
Union
[
LLMGuidedOptions
,
GuidedDecodingRequest
]]
=
None
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
"""Generates the completions for the input prompts.
"""Generates the completions for the input prompts.
...
@@ -303,6 +308,14 @@ class LLM:
...
@@ -303,6 +308,14 @@ class LLM:
else
:
else
:
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
if
isinstance
(
guided_options_request
,
dict
):
if
len
(
guided_options_request
)
>
1
:
raise
ValueError
(
"You can only use one guided decoding but multiple is "
f
"specified:
{
guided_options_request
}
"
)
guided_options_request
=
GuidedDecodingRequest
(
**
guided_options_request
)
if
sampling_params
is
None
:
if
sampling_params
is
None
:
# Use default sampling params.
# Use default sampling params.
sampling_params
=
SamplingParams
()
sampling_params
=
SamplingParams
()
...
@@ -311,7 +324,8 @@ class LLM:
...
@@ -311,7 +324,8 @@ class LLM:
inputs
=
inputs
,
inputs
=
inputs
,
params
=
sampling_params
,
params
=
sampling_params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
guided_options
=
guided_options_request
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
...
@@ -508,6 +522,7 @@ class LLM:
...
@@ -508,6 +522,7 @@ class LLM:
Sequence
[
PoolingParams
]],
Sequence
[
PoolingParams
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
)
->
None
:
)
->
None
:
if
isinstance
(
inputs
,
(
str
,
dict
)):
if
isinstance
(
inputs
,
(
str
,
dict
)):
# Convert a single prompt to a list.
# Convert a single prompt to a list.
...
@@ -523,6 +538,15 @@ class LLM:
...
@@ -523,6 +538,15 @@ class LLM:
raise
ValueError
(
"The lengths of prompts and lora_request "
raise
ValueError
(
"The lengths of prompts and lora_request "
"must be the same."
)
"must be the same."
)
if
isinstance
(
params
,
list
):
params
=
[
self
.
_add_guided_processor
(
param
,
guided_options
)
if
isinstance
(
param
,
SamplingParams
)
else
param
for
param
in
params
]
elif
isinstance
(
params
,
SamplingParams
):
params
=
self
.
_add_guided_processor
(
params
,
guided_options
)
# Add requests to the engine.
# Add requests to the engine.
for
i
,
request_inputs
in
enumerate
(
inputs
):
for
i
,
request_inputs
in
enumerate
(
inputs
):
self
.
_add_request
(
self
.
_add_request
(
...
@@ -548,6 +572,24 @@ class LLM:
...
@@ -548,6 +572,24 @@ class LLM:
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
)
def
_add_guided_processor
(
self
,
params
:
SamplingParams
,
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
):
if
guided_options
:
if
guided_options
.
guided_decoding_backend
is
None
:
decoding_config
=
self
.
llm_engine
.
get_decoding_config
()
guided_options
.
guided_decoding_backend
=
(
decoding_config
.
guided_decoding_backend
)
guided_logits_processor
=
get_local_guided_decoding_logits_processor
(
#noqa
guided_options
.
guided_decoding_backend
,
guided_options
,
self
.
get_tokenizer
())
if
guided_logits_processor
:
if
params
.
logits_processors
is
None
:
params
.
logits_processors
=
[]
params
.
logits_processors
.
append
(
guided_logits_processor
)
return
params
def
_run_engine
(
def
_run_engine
(
self
,
*
,
use_tqdm
:
bool
self
,
*
,
use_tqdm
:
bool
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
...
...
vllm/entrypoints/openai/protocol.py
View file @
654bc5ca
# Adapted from
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import
time
import
time
from
argparse
import
Namespace
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
import
torch
import
torch
...
@@ -14,6 +15,23 @@ from vllm.pooling_params import PoolingParams
...
@@ -14,6 +15,23 @@ from vllm.pooling_params import PoolingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO
=
Namespace
(
min
=-
9223372036854775808
,
max
=
9223372036854775807
)
try
:
from
sphinx.ext.autodoc.mock
import
_MockModule
if
isinstance
(
torch
,
_MockModule
):
_LONG_INFO
=
_MOCK_LONG_INFO
else
:
_LONG_INFO
=
torch
.
iinfo
(
torch
.
long
)
except
ModuleNotFoundError
:
_LONG_INFO
=
torch
.
iinfo
(
torch
.
long
)
assert
_LONG_INFO
.
min
==
_MOCK_LONG_INFO
.
min
assert
_LONG_INFO
.
max
==
_MOCK_LONG_INFO
.
max
class
OpenAIBaseModel
(
BaseModel
):
class
OpenAIBaseModel
(
BaseModel
):
# OpenAI API does not allow extra fields
# OpenAI API does not allow extra fields
...
@@ -108,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -108,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
n
:
Optional
[
int
]
=
1
n
:
Optional
[
int
]
=
1
presence_penalty
:
Optional
[
float
]
=
0.0
presence_penalty
:
Optional
[
float
]
=
0.0
response_format
:
Optional
[
ResponseFormat
]
=
None
response_format
:
Optional
[
ResponseFormat
]
=
None
seed
:
Optional
[
int
]
=
Field
(
None
,
seed
:
Optional
[
int
]
=
Field
(
None
,
ge
=
_LONG_INFO
.
min
,
le
=
_LONG_INFO
.
max
)
ge
=
torch
.
iinfo
(
torch
.
long
).
min
,
le
=
torch
.
iinfo
(
torch
.
long
).
max
)
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stream
:
Optional
[
bool
]
=
False
stream
:
Optional
[
bool
]
=
False
stream_options
:
Optional
[
StreamOptions
]
=
None
stream_options
:
Optional
[
StreamOptions
]
=
None
...
@@ -327,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -327,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel):
max_tokens
:
Optional
[
int
]
=
16
max_tokens
:
Optional
[
int
]
=
16
n
:
int
=
1
n
:
int
=
1
presence_penalty
:
Optional
[
float
]
=
0.0
presence_penalty
:
Optional
[
float
]
=
0.0
seed
:
Optional
[
int
]
=
Field
(
None
,
seed
:
Optional
[
int
]
=
Field
(
None
,
ge
=
_LONG_INFO
.
min
,
le
=
_LONG_INFO
.
max
)
ge
=
torch
.
iinfo
(
torch
.
long
).
min
,
le
=
torch
.
iinfo
(
torch
.
long
).
max
)
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
Field
(
default_factory
=
list
)
stream
:
Optional
[
bool
]
=
False
stream
:
Optional
[
bool
]
=
False
stream_options
:
Optional
[
StreamOptions
]
=
None
stream_options
:
Optional
[
StreamOptions
]
=
None
...
...
vllm/model_executor/guided_decoding/__init__.py
View file @
654bc5ca
...
@@ -3,9 +3,10 @@ from typing import Optional, Union
...
@@ -3,9 +3,10 @@ from typing import Optional, Union
from
vllm.entrypoints.openai.protocol
import
(
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionNamedToolChoiceParam
,
ChatCompletionRequest
,
ChatCompletionNamedToolChoiceParam
,
ChatCompletionRequest
,
CompletionRequest
)
CompletionRequest
)
from
vllm.model_executor.guided_decoding.
lm_format_enforcer_decoding
import
(
from
vllm.model_executor.guided_decoding.
guided_fields
import
(
get_lm_format_enforcer_g
uided
_d
ecoding
_logits_processor
)
G
uided
D
ecoding
Request
)
from
vllm.model_executor.guided_decoding.outlines_decoding
import
(
from
vllm.model_executor.guided_decoding.outlines_decoding
import
(
get_local_outlines_guided_decoding_logits_processor
,
get_outlines_guided_decoding_logits_processor
)
get_outlines_guided_decoding_logits_processor
)
from
vllm.sampling_params
import
LogitsProcessor
from
vllm.sampling_params
import
LogitsProcessor
...
@@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor(
...
@@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor(
return
await
get_outlines_guided_decoding_logits_processor
(
return
await
get_outlines_guided_decoding_logits_processor
(
request
,
tokenizer
)
request
,
tokenizer
)
if
guided_decoding_backend
==
'lm-format-enforcer'
:
if
guided_decoding_backend
==
'lm-format-enforcer'
:
from
vllm.model_executor.guided_decoding.lm_format_enforcer_decoding
import
(
# noqa
get_lm_format_enforcer_guided_decoding_logits_processor
)
return
await
get_lm_format_enforcer_guided_decoding_logits_processor
(
return
await
get_lm_format_enforcer_guided_decoding_logits_processor
(
request
,
tokenizer
)
request
,
tokenizer
)
...
@@ -28,6 +31,25 @@ async def get_guided_decoding_logits_processor(
...
@@ -28,6 +31,25 @@ async def get_guided_decoding_logits_processor(
"Must be one of 'outlines, 'lm-format-enforcer'"
)
"Must be one of 'outlines, 'lm-format-enforcer'"
)
def
get_local_guided_decoding_logits_processor
(
guided_decoding_backend
:
str
,
guided_options
:
GuidedDecodingRequest
,
tokenizer
)
->
Optional
[
LogitsProcessor
]:
# request = _adapt_request_for_tool_use(request)
if
guided_decoding_backend
==
'outlines'
:
return
get_local_outlines_guided_decoding_logits_processor
(
guided_options
,
tokenizer
)
if
guided_decoding_backend
==
'lm-format-enforcer'
:
from
vllm.model_executor.guided_decoding.lm_format_enforcer_decoding
import
(
# noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor
)
return
get_local_lm_format_enforcer_guided_decoding_logits_processor
(
guided_options
,
tokenizer
)
raise
ValueError
(
f
"Unknown guided decoding backend '
{
guided_decoding_backend
}
'. "
"Must be one of 'outlines, 'lm-format-enforcer'"
)
def
_adapt_request_for_tool_use
(
request
:
Union
[
CompletionRequest
,
def
_adapt_request_for_tool_use
(
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
]):
ChatCompletionRequest
]):
# the legacy completion API does not support tool use
# the legacy completion API does not support tool use
...
...
vllm/model_executor/guided_decoding/guided_fields.py
0 → 100644
View file @
654bc5ca
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
TypedDict
,
Union
from
pydantic
import
BaseModel
class
LLMGuidedOptions
(
TypedDict
,
total
=
False
):
guided_json
:
Union
[
Dict
,
BaseModel
,
str
]
guided_regex
:
str
guided_choice
:
List
[
str
]
guided_grammar
:
str
guided_decoding_backend
:
str
guided_whitespace_pattern
:
str
guided_json_object
:
bool
@
dataclass
class
GuidedDecodingRequest
:
"""One of the fields will be used to retrieve the logit processor."""
guided_json
:
Optional
[
Union
[
Dict
,
BaseModel
,
str
]]
=
None
guided_regex
:
Optional
[
str
]
=
None
guided_choice
:
Optional
[
List
[
str
]]
=
None
guided_grammar
:
Optional
[
str
]
=
None
guided_decoding_backend
:
Optional
[
str
]
=
None
guided_whitespace_pattern
:
Optional
[
str
]
=
None
guided_json_object
:
Optional
[
bool
]
=
None
def
__post_init__
(
self
):
"""Validate that some fields are mutually exclusive."""
guide_count
=
sum
([
self
.
guided_json
is
not
None
,
self
.
guided_regex
is
not
None
,
self
.
guided_choice
is
not
None
,
self
.
guided_grammar
is
not
None
,
self
.
guided_json_object
is
not
None
])
if
guide_count
>
1
:
raise
ValueError
(
"You can only use one kind of guided decoding but multiple are "
f
"specified:
{
self
.
__dict__
}
"
)
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
View file @
654bc5ca
...
@@ -12,7 +12,10 @@ from transformers import PreTrainedTokenizerBase
...
@@ -12,7 +12,10 @@ from transformers import PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
)
CompletionRequest
)
from
vllm.model_executor.guided_decoding.guided_fields
import
(
GuidedDecodingRequest
)
from
vllm.model_executor.guided_decoding.outlines_decoding
import
(
from
vllm.model_executor.guided_decoding.outlines_decoding
import
(
get_local_outlines_guided_decoding_logits_processor
,
get_outlines_guided_decoding_logits_processor
)
get_outlines_guided_decoding_logits_processor
)
from
vllm.sampling_params
import
LogitsProcessor
from
vllm.sampling_params
import
LogitsProcessor
...
@@ -54,6 +57,42 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
...
@@ -54,6 +57,42 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
return
logits_processor
return
logits_processor
def
get_local_lm_format_enforcer_guided_decoding_logits_processor
(
guided_options
:
GuidedDecodingRequest
,
tokenizer
)
->
Optional
[
LogitsProcessor
]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
tokenizer_data
=
_cached_build_vllm_token_enforcer_tokenizer_data
(
tokenizer
)
character_level_parser
:
CharacterLevelParser
if
guided_options
.
guided_json
:
schema
=
_normalize_json_schema_object
(
guided_options
.
guided_json
)
character_level_parser
=
JsonSchemaParser
(
schema
)
elif
guided_options
.
guided_choice
:
character_level_parser
=
UnionParser
(
[
StringParser
(
choice
)
for
choice
in
guided_options
.
guided_choice
])
elif
guided_options
.
guided_regex
:
character_level_parser
=
RegexParser
(
guided_options
.
guided_regex
)
elif
guided_options
.
guided_grammar
:
# CFG grammar not supported by LMFE, revert to outlines
return
get_local_outlines_guided_decoding_logits_processor
(
guided_options
,
tokenizer
)
elif
guided_options
.
guided_json_object
:
# None means any json object
character_level_parser
=
JsonSchemaParser
(
None
)
else
:
return
None
logits_processor
=
build_vllm_logits_processor
(
tokenizer_data
,
character_level_parser
)
return
logits_processor
def
_normalize_json_schema_object
(
schema
:
Union
[
str
,
dict
,
BaseModel
])
->
dict
:
def
_normalize_json_schema_object
(
schema
:
Union
[
str
,
dict
,
BaseModel
])
->
dict
:
if
isinstance
(
schema
,
str
):
if
isinstance
(
schema
,
str
):
return
json_loads
(
schema
)
return
json_loads
(
schema
)
...
...
vllm/model_executor/guided_decoding/outlines_decoding.py
View file @
654bc5ca
...
@@ -10,6 +10,8 @@ from transformers import PreTrainedTokenizerBase
...
@@ -10,6 +10,8 @@ from transformers import PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
)
CompletionRequest
)
from
vllm.model_executor.guided_decoding.guided_fields
import
(
GuidedDecodingRequest
)
from
vllm.model_executor.guided_decoding.outlines_logits_processors
import
(
from
vllm.model_executor.guided_decoding.outlines_logits_processors
import
(
CFGLogitsProcessor
,
JSONLogitsProcessor
,
RegexLogitsProcessor
)
CFGLogitsProcessor
,
JSONLogitsProcessor
,
RegexLogitsProcessor
)
...
@@ -77,8 +79,27 @@ async def get_outlines_guided_decoding_logits_processor(
...
@@ -77,8 +79,27 @@ async def get_outlines_guided_decoding_logits_processor(
mode
,
request
.
guided_whitespace_pattern
)
mode
,
request
.
guided_whitespace_pattern
)
def
get_local_outlines_guided_decoding_logits_processor
(
guided_options
:
GuidedDecodingRequest
,
tokenizer
:
PreTrainedTokenizerBase
)
->
Union
[
JSONLogitsProcessor
,
RegexLogitsProcessor
,
CFGLogitsProcessor
,
None
]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
guide
,
mode
=
_get_guide_and_mode
(
guided_options
)
if
not
guide
or
not
mode
:
return
None
return
_get_logits_processor
(
guide
,
tokenizer
,
mode
,
guided_options
.
guided_whitespace_pattern
)
def
_get_guide_and_mode
(
def
_get_guide_and_mode
(
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
]
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
,
GuidedDecodingRequest
]
)
->
Union
[
Tuple
[
str
,
GuidedDecodingMode
],
Tuple
[
None
,
None
]]:
)
->
Union
[
Tuple
[
str
,
GuidedDecodingMode
],
Tuple
[
None
,
None
]]:
if
request
.
guided_json
:
if
request
.
guided_json
:
...
@@ -102,7 +123,8 @@ def _get_guide_and_mode(
...
@@ -102,7 +123,8 @@ def _get_guide_and_mode(
return
choices_regex
,
GuidedDecodingMode
.
CHOICE
return
choices_regex
,
GuidedDecodingMode
.
CHOICE
elif
request
.
guided_grammar
:
elif
request
.
guided_grammar
:
return
request
.
guided_grammar
,
GuidedDecodingMode
.
GRAMMAR
return
request
.
guided_grammar
,
GuidedDecodingMode
.
GRAMMAR
elif
(
request
.
response_format
is
not
None
elif
(
not
isinstance
(
request
,
GuidedDecodingRequest
)
and
request
.
response_format
is
not
None
and
request
.
response_format
.
type
==
"json_object"
):
and
request
.
response_format
.
type
==
"json_object"
):
return
JSON_GRAMMAR
,
GuidedDecodingMode
.
GRAMMAR
return
JSON_GRAMMAR
,
GuidedDecodingMode
.
GRAMMAR
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment