Unverified Commit 77a318bd authored by Aaron Pham's avatar Aaron Pham Committed by GitHub
Browse files

[V1][Core] Support MistralTokenizer for Structured Output (#14625)


Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
parent 80e78d02
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import json import json
import re import re
from typing import Any
import jsonschema import jsonschema
import pytest import pytest
...@@ -10,17 +13,27 @@ from vllm.entrypoints.llm import LLM ...@@ -10,17 +13,27 @@ from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"] GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
@pytest.fixture
def model_name():
return [
"Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
]
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1) GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_completion(monkeypatch, sample_json_schema, def test_guided_json_completion(
guided_decoding_backend: str): monkeypatch: pytest.MonkeyPatch,
sample_json_schema: dict[str, Any],
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024) llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0, sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams( guided_decoding=GuidedDecodingParams(
...@@ -50,9 +63,13 @@ def test_guided_json_completion(monkeypatch, sample_json_schema, ...@@ -50,9 +63,13 @@ def test_guided_json_completion(monkeypatch, sample_json_schema,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1) GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_object(monkeypatch, guided_decoding_backend: str): def test_guided_json_object(
monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024) llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0, sampling_params = SamplingParams(temperature=1.0,
max_tokens=100, max_tokens=100,
n=2, n=2,
...@@ -84,10 +101,14 @@ def test_guided_json_object(monkeypatch, guided_decoding_backend: str): ...@@ -84,10 +101,14 @@ def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1) GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema, def test_guided_json_unsupported_schema(
guided_decoding_backend: str): monkeypatch: pytest.MonkeyPatch,
unsupported_json_schema: dict[str, Any],
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024) llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0, sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000, max_tokens=1000,
guided_decoding=GuidedDecodingParams( guided_decoding=GuidedDecodingParams(
...@@ -107,10 +128,14 @@ def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema, ...@@ -107,10 +128,14 @@ def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1) GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf, def test_guided_grammar_ebnf(
guided_decoding_backend: str): monkeypatch: pytest.MonkeyPatch,
sample_sql_ebnf: str,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024) llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8, sampling_params = SamplingParams(temperature=0.8,
top_p=0.95, top_p=0.95,
max_tokens=1000, max_tokens=1000,
...@@ -145,10 +170,14 @@ def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf, ...@@ -145,10 +170,14 @@ def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1) GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_lark(monkeypatch, sample_sql_lark, def test_guided_grammar_lark(
guided_decoding_backend: str): monkeypatch: pytest.MonkeyPatch,
sample_sql_lark: str,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024) llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8, sampling_params = SamplingParams(temperature=0.8,
top_p=0.95, top_p=0.95,
max_tokens=1000, max_tokens=1000,
...@@ -188,10 +217,13 @@ def test_guided_grammar_lark(monkeypatch, sample_sql_lark, ...@@ -188,10 +217,13 @@ def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1) GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_ebnf_invalid(monkeypatch, def test_guided_grammar_ebnf_invalid(
guided_decoding_backend: str): monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024) llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8, sampling_params = SamplingParams(temperature=0.8,
top_p=0.95, top_p=0.95,
max_tokens=1000, max_tokens=1000,
...@@ -212,9 +244,14 @@ def test_guided_grammar_ebnf_invalid(monkeypatch, ...@@ -212,9 +244,14 @@ def test_guided_grammar_ebnf_invalid(monkeypatch,
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1) GUIDED_DECODING_BACKENDS_V1)
def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str): def test_guided_regex(
monkeypatch: pytest.MonkeyPatch,
sample_regex: str,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024) llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8, sampling_params = SamplingParams(temperature=0.8,
top_p=0.95, top_p=0.95,
guided_decoding=GuidedDecodingParams( guided_decoding=GuidedDecodingParams(
...@@ -243,10 +280,14 @@ def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str): ...@@ -243,10 +280,14 @@ def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1) GUIDED_DECODING_BACKENDS_V1)
def test_guided_choice_completion(monkeypatch, sample_guided_choice, def test_guided_choice_completion(
guided_decoding_backend: str): monkeypatch: pytest.MonkeyPatch,
sample_guided_choice: str,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024) llm = LLM(model=model_name, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8, sampling_params = SamplingParams(temperature=0.8,
top_p=0.95, top_p=0.95,
guided_decoding=GuidedDecodingParams( guided_decoding=GuidedDecodingParams(
......
...@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Optional ...@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader from vllm.utils import LazyLoader
from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
...@@ -40,8 +41,40 @@ class StructuredOutputManager: ...@@ -40,8 +41,40 @@ class StructuredOutputManager:
tokenizer_group.ping() tokenizer_group.ping()
tokenizer = tokenizer_group.get_lora_tokenizer(None) tokenizer = tokenizer_group.get_lora_tokenizer(None)
tokenizer_info = xgr.TokenizerInfo.from_huggingface( if isinstance(tokenizer, MistralTokenizer):
tokenizer, vocab_size=self.vocab_size) # NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try:
encoded_vocab = [
token for token, _ in sorted(
tokenizer.get_vocab().items(),
key=lambda x: x[1],
)
]
stop_token_ids = None
if hasattr(
tokenizer,
"eos_token_id",
) and tokenizer.eos_token_id is not None:
stop_token_ids = [tokenizer.eos_token_id]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(tokenizer)}. The tokenizer should have a "
"get_vocab method.") from e
tokenizer_info = xgr.TokenizerInfo(
encoded_vocab=encoded_vocab,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type=xgr.VocabType.BYTE_FALLBACK,
vocab_size=self.vocab_size,
stop_token_ids=stop_token_ids,
add_prefix_space=True,
)
else:
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer,
vocab_size=self.vocab_size,
)
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
# The default max_workers if not specified is the number of CPUs * 5, # The default max_workers if not specified is the number of CPUs * 5,
...@@ -51,7 +84,9 @@ class StructuredOutputManager: ...@@ -51,7 +84,9 @@ class StructuredOutputManager:
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers) self.executor = ThreadPoolExecutor(max_workers=max_workers)
self._grammar_bitmask = xgr.allocate_token_bitmask( self._grammar_bitmask = xgr.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size) self.vllm_config.scheduler_config.max_num_seqs,
self.vocab_size,
)
self.init_complete = True self.init_complete = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment