Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
77a318bd
Unverified
Commit
77a318bd
authored
Mar 11, 2025
by
Aaron Pham
Committed by
GitHub
Mar 12, 2025
Browse files
[V1][Core] Support MistralTokenizer for Structured Output (#14625)
Signed-off-by:
Aaron Pham
<
contact@aarnphm.xyz
>
parent
80e78d02
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
102 additions
and
26 deletions
+102
-26
tests/v1/entrypoints/llm/test_struct_output_generate.py
tests/v1/entrypoints/llm/test_struct_output_generate.py
+64
-23
vllm/v1/structured_output/__init__.py
vllm/v1/structured_output/__init__.py
+38
-3
No files found.
tests/v1/entrypoints/llm/test_struct_output_generate.py
View file @
77a318bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
import
json
import
json
import
re
import
re
from
typing
import
Any
import
jsonschema
import
jsonschema
import
pytest
import
pytest
...
@@ -10,17 +13,27 @@ from vllm.entrypoints.llm import LLM
...
@@ -10,17 +13,27 @@ from vllm.entrypoints.llm import LLM
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
MODEL_NAME
=
"Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS_V1
=
[
"xgrammar"
]
GUIDED_DECODING_BACKENDS_V1
=
[
"xgrammar"
]
@
pytest
.
fixture
def
model_name
():
return
[
"Qwen/Qwen2.5-1.5B-Instruct"
,
"mistralai/Ministral-8B-Instruct-2410"
]
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS_V1
)
GUIDED_DECODING_BACKENDS_V1
)
def
test_guided_json_completion
(
monkeypatch
,
sample_json_schema
,
def
test_guided_json_completion
(
guided_decoding_backend
:
str
):
monkeypatch
:
pytest
.
MonkeyPatch
,
sample_json_schema
:
dict
[
str
,
Any
],
guided_decoding_backend
:
str
,
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
MODEL_NAME
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
1000
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
guided_decoding
=
GuidedDecodingParams
(
...
@@ -50,9 +63,13 @@ def test_guided_json_completion(monkeypatch, sample_json_schema,
...
@@ -50,9 +63,13 @@ def test_guided_json_completion(monkeypatch, sample_json_schema,
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS_V1
)
GUIDED_DECODING_BACKENDS_V1
)
def
test_guided_json_object
(
monkeypatch
,
guided_decoding_backend
:
str
):
def
test_guided_json_object
(
monkeypatch
:
pytest
.
MonkeyPatch
,
guided_decoding_backend
:
str
,
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
MODEL_NAME
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
100
,
max_tokens
=
100
,
n
=
2
,
n
=
2
,
...
@@ -84,10 +101,14 @@ def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
...
@@ -84,10 +101,14 @@ def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS_V1
)
GUIDED_DECODING_BACKENDS_V1
)
def
test_guided_json_unsupported_schema
(
monkeypatch
,
unsupported_json_schema
,
def
test_guided_json_unsupported_schema
(
guided_decoding_backend
:
str
):
monkeypatch
:
pytest
.
MonkeyPatch
,
unsupported_json_schema
:
dict
[
str
,
Any
],
guided_decoding_backend
:
str
,
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
MODEL_NAME
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
1000
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
guided_decoding
=
GuidedDecodingParams
(
...
@@ -107,10 +128,14 @@ def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
...
@@ -107,10 +128,14 @@ def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS_V1
)
GUIDED_DECODING_BACKENDS_V1
)
def
test_guided_grammar_ebnf
(
monkeypatch
,
sample_sql_ebnf
,
def
test_guided_grammar_ebnf
(
guided_decoding_backend
:
str
):
monkeypatch
:
pytest
.
MonkeyPatch
,
sample_sql_ebnf
:
str
,
guided_decoding_backend
:
str
,
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
MODEL_NAME
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
top_p
=
0.95
,
max_tokens
=
1000
,
max_tokens
=
1000
,
...
@@ -145,10 +170,14 @@ def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
...
@@ -145,10 +170,14 @@ def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS_V1
)
GUIDED_DECODING_BACKENDS_V1
)
def
test_guided_grammar_lark
(
monkeypatch
,
sample_sql_lark
,
def
test_guided_grammar_lark
(
guided_decoding_backend
:
str
):
monkeypatch
:
pytest
.
MonkeyPatch
,
sample_sql_lark
:
str
,
guided_decoding_backend
:
str
,
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
MODEL_NAME
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
top_p
=
0.95
,
max_tokens
=
1000
,
max_tokens
=
1000
,
...
@@ -188,10 +217,13 @@ def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
...
@@ -188,10 +217,13 @@ def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS_V1
)
GUIDED_DECODING_BACKENDS_V1
)
def
test_guided_grammar_ebnf_invalid
(
monkeypatch
,
def
test_guided_grammar_ebnf_invalid
(
guided_decoding_backend
:
str
):
monkeypatch
:
pytest
.
MonkeyPatch
,
guided_decoding_backend
:
str
,
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
MODEL_NAME
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
top_p
=
0.95
,
max_tokens
=
1000
,
max_tokens
=
1000
,
...
@@ -212,9 +244,14 @@ def test_guided_grammar_ebnf_invalid(monkeypatch,
...
@@ -212,9 +244,14 @@ def test_guided_grammar_ebnf_invalid(monkeypatch,
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS_V1
)
GUIDED_DECODING_BACKENDS_V1
)
def
test_guided_regex
(
monkeypatch
,
sample_regex
,
guided_decoding_backend
:
str
):
def
test_guided_regex
(
monkeypatch
:
pytest
.
MonkeyPatch
,
sample_regex
:
str
,
guided_decoding_backend
:
str
,
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
MODEL_NAME
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
top_p
=
0.95
,
guided_decoding
=
GuidedDecodingParams
(
guided_decoding
=
GuidedDecodingParams
(
...
@@ -243,10 +280,14 @@ def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
...
@@ -243,10 +280,14 @@ def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS_V1
)
GUIDED_DECODING_BACKENDS_V1
)
def
test_guided_choice_completion
(
monkeypatch
,
sample_guided_choice
,
def
test_guided_choice_completion
(
guided_decoding_backend
:
str
):
monkeypatch
:
pytest
.
MonkeyPatch
,
sample_guided_choice
:
str
,
guided_decoding_backend
:
str
,
model_name
:
str
,
):
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
MODEL_NAME
,
max_model_len
=
1024
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
top_p
=
0.95
,
guided_decoding
=
GuidedDecodingParams
(
guided_decoding
=
GuidedDecodingParams
(
...
...
vllm/v1/structured_output/__init__.py
View file @
77a318bd
...
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Optional
...
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Optional
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
from
vllm.utils
import
LazyLoader
from
vllm.utils
import
LazyLoader
from
vllm.v1.structured_output.grammar
import
Grammar
,
StructuredOutputOptions
from
vllm.v1.structured_output.grammar
import
Grammar
,
StructuredOutputOptions
...
@@ -40,8 +41,40 @@ class StructuredOutputManager:
...
@@ -40,8 +41,40 @@ class StructuredOutputManager:
tokenizer_group
.
ping
()
tokenizer_group
.
ping
()
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
None
)
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
None
)
tokenizer_info
=
xgr
.
TokenizerInfo
.
from_huggingface
(
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tokenizer
,
vocab_size
=
self
.
vocab_size
)
# NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try
:
encoded_vocab
=
[
token
for
token
,
_
in
sorted
(
tokenizer
.
get_vocab
().
items
(),
key
=
lambda
x
:
x
[
1
],
)
]
stop_token_ids
=
None
if
hasattr
(
tokenizer
,
"eos_token_id"
,
)
and
tokenizer
.
eos_token_id
is
not
None
:
stop_token_ids
=
[
tokenizer
.
eos_token_id
]
except
AttributeError
as
e
:
raise
ValueError
(
f
"Cannot get the vocabulary of the tokenizer "
f
"
{
type
(
tokenizer
)
}
. The tokenizer should have a "
"get_vocab method."
)
from
e
tokenizer_info
=
xgr
.
TokenizerInfo
(
encoded_vocab
=
encoded_vocab
,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type
=
xgr
.
VocabType
.
BYTE_FALLBACK
,
vocab_size
=
self
.
vocab_size
,
stop_token_ids
=
stop_token_ids
,
add_prefix_space
=
True
,
)
else
:
tokenizer_info
=
xgr
.
TokenizerInfo
.
from_huggingface
(
tokenizer
,
vocab_size
=
self
.
vocab_size
,
)
self
.
compiler
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
8
)
self
.
compiler
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
8
)
# The default max_workers if not specified is the number of CPUs * 5,
# The default max_workers if not specified is the number of CPUs * 5,
...
@@ -51,7 +84,9 @@ class StructuredOutputManager:
...
@@ -51,7 +84,9 @@ class StructuredOutputManager:
max_workers
=
max
(
1
,
(
multiprocessing
.
cpu_count
()
+
1
)
//
2
)
max_workers
=
max
(
1
,
(
multiprocessing
.
cpu_count
()
+
1
)
//
2
)
self
.
executor
=
ThreadPoolExecutor
(
max_workers
=
max_workers
)
self
.
executor
=
ThreadPoolExecutor
(
max_workers
=
max_workers
)
self
.
_grammar_bitmask
=
xgr
.
allocate_token_bitmask
(
self
.
_grammar_bitmask
=
xgr
.
allocate_token_bitmask
(
self
.
vllm_config
.
scheduler_config
.
max_num_seqs
,
self
.
vocab_size
)
self
.
vllm_config
.
scheduler_config
.
max_num_seqs
,
self
.
vocab_size
,
)
self
.
init_complete
=
True
self
.
init_complete
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment