"launch/dynamo-run/vscode:/vscode.git/clone" did not exist on "0e77d3442ee227eeaa84ac013633ce67be6b99b8"
Unverified Commit 34a98427 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Refactor tokenizer interface (#29693)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent f223ed41
......@@ -316,7 +316,7 @@ steps:
source_file_dependencies:
- vllm/
- tests/engine
- tests/tokenization
- tests/tokenizers_
- tests/test_sequence
- tests/test_config
- tests/test_logger
......@@ -324,7 +324,7 @@ steps:
commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
# OOM in the CI unless we run this separately
- pytest -v -s tokenization
- pytest -v -s tokenizers_
- label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45
......
......@@ -282,7 +282,7 @@ steps:
source_file_dependencies:
- vllm/
- tests/engine
- tests/tokenization
- tests/tokenizers_
- tests/test_sequence
- tests/test_config
- tests/test_logger
......@@ -290,7 +290,7 @@ steps:
commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
# OOM in the CI unless we run this separately
- pytest -v -s tokenization
- pytest -v -s tokenizers_
- label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45
......
......@@ -620,7 +620,7 @@ def get_tokenizer(
kwargs["use_fast"] = False
if tokenizer_mode == "mistral":
try:
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
except ImportError as e:
raise ImportError(
"MistralTokenizer requires vllm package.\n"
......
......@@ -216,14 +216,13 @@ You can add a new `ReasoningParser` similar to [vllm/reasoning/deepseek_r1_reaso
# import the required packages
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
# define a reasoning parser and register it to vllm
# the name list in register_module can be used
# in --reasoning-parser.
class ExampleParser(ReasoningParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def extract_reasoning_streaming(
......
......@@ -422,7 +422,7 @@ Here is a summary of a plugin file:
# in --tool-call-parser. you can define as many
# tool parsers as you want here.
class ExampleToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# adjust request. e.g.: set skip special tokens
......
......@@ -10,7 +10,7 @@ import pytest
from vllm.config import ModelConfig
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
@pytest.fixture()
......
......@@ -4,9 +4,9 @@
import pytest
from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
@pytest.fixture(scope="function")
def default_tokenizer() -> AnyTokenizer:
def default_tokenizer() -> TokenizerLike:
return AutoTokenizer.from_pretrained("gpt2")
......@@ -7,7 +7,7 @@ import pytest
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from ....utils import RemoteOpenAIServer
......@@ -270,14 +270,14 @@ async def test_streaming_product_tool_call():
@pytest.fixture
def qwen_tokenizer() -> AnyTokenizer:
def qwen_tokenizer() -> TokenizerLike:
from vllm.transformers_utils.tokenizer import get_tokenizer
return get_tokenizer("Qwen/Qwen3-32B")
@pytest.fixture
def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser:
def hermes_parser(qwen_tokenizer: TokenizerLike) -> Hermes2ProToolParser:
return Hermes2ProToolParser(qwen_tokenizer)
......@@ -291,7 +291,7 @@ def any_chat_request() -> ChatCompletionRequest:
def test_hermes_parser_streaming_just_forward_text(
qwen_tokenizer: AnyTokenizer,
qwen_tokenizer: TokenizerLike,
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
......@@ -323,7 +323,7 @@ def test_hermes_parser_streaming_just_forward_text(
def test_hermes_parser_streaming_failure_case_bug_19056(
qwen_tokenizer: AnyTokenizer,
qwen_tokenizer: TokenizerLike,
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
......@@ -357,7 +357,7 @@ def test_hermes_parser_streaming_failure_case_bug_19056(
def test_hermes_parser_streaming(
qwen_tokenizer: AnyTokenizer,
qwen_tokenizer: TokenizerLike,
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
......
......@@ -7,11 +7,11 @@ import pytest
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
@pytest.fixture
def parser(default_tokenizer: AnyTokenizer):
def parser(default_tokenizer: TokenizerLike):
return Llama3JsonToolParser(default_tokenizer)
......
......@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
# Test cases similar to pythonic parser but with Llama4 specific format
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
......@@ -64,7 +64,7 @@ PYTHON_TAG_FUNCTION_OUTPUT = (
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer
)
......@@ -208,7 +208,7 @@ def test_tool_call(
streaming: bool,
model_output: str,
expected_tool_calls: list[FunctionCall],
default_tokenizer: AnyTokenizer,
default_tokenizer: TokenizerLike,
):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer
......@@ -224,7 +224,7 @@ def test_tool_call(
assert actual.function == expected
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer
)
......@@ -246,7 +246,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
"""test regex timeout is handled gracefully"""
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer
......
......@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
......@@ -69,7 +69,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer
)
......@@ -188,7 +188,7 @@ def test_tool_call(
streaming: bool,
model_output: str,
expected_tool_calls: list[FunctionCall],
default_tokenizer: AnyTokenizer,
default_tokenizer: TokenizerLike,
):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer
......@@ -205,7 +205,7 @@ def test_tool_call(
assert actual.function == expected
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer
)
......@@ -228,7 +228,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
"""test regex timeout is handled gracefully"""
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer
......
......@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
......@@ -61,7 +61,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer
)
......@@ -168,7 +168,7 @@ def test_tool_call(
streaming: bool,
model_output: str,
expected_tool_calls: list[FunctionCall],
default_tokenizer: AnyTokenizer,
default_tokenizer: TokenizerLike,
):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer
......@@ -185,7 +185,7 @@ def test_tool_call(
assert actual.function == expected
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer
)
......@@ -208,7 +208,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
"""test regex timeout is handled gracefully"""
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer
......
......@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
class StreamingToolReconstructor:
......@@ -111,7 +111,7 @@ def run_tool_extraction_nonstreaming(
return tool_parser.extract_tool_calls(model_output, request)
def split_string_into_token_deltas(tokenizer: AnyTokenizer, text: str) -> list[str]:
def split_string_into_token_deltas(tokenizer: TokenizerLike, text: str) -> list[str]:
# Split a string into a series of deltas using the provided tokenizer. Each
# delta will be the string equivalent of a single token.
token_ids = tokenizer.encode(text, add_special_tokens=False)
......
......@@ -28,8 +28,8 @@ from vllm.multimodal.utils import (
encode_image_base64,
encode_video_base64,
)
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH
......
......@@ -10,7 +10,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolParser,
)
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
from ...utils import check_logprobs_close
......
......@@ -9,7 +9,7 @@ from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
from ....conftest import AudioTestAssets
from ....utils import RemoteOpenAIServer
......
......@@ -9,7 +9,7 @@ import torch
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config.model import RunnerOption
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from .....conftest import HfRunner, VllmRunner
from ....registry import HF_EXAMPLE_MODELS
......@@ -33,7 +33,7 @@ def run_test(
auto_cls: type[_BaseAutoModelClass],
use_tokenizer_eos: bool,
comparator: Callable[..., None],
get_stop_token_ids: Callable[[AnyTokenizer], list[int]] | None,
get_stop_token_ids: Callable[[TokenizerLike], list[int]] | None,
stop_str: list[str] | None,
limit_mm_per_prompt: dict[str, int],
vllm_runner_kwargs: dict[str, Any] | None,
......
......@@ -14,7 +14,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config.model import RunnerOption
from vllm.logprobs import SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from .....conftest import (
AUDIO_ASSETS,
......@@ -126,7 +126,7 @@ class VLMTestInfo(NamedTuple):
vllm_runner_kwargs: dict[str, Any] | None = None
# Optional callable which gets a list of token IDs from the model tokenizer
get_stop_token_ids: Callable[[AnyTokenizer], list[int]] | None = None
get_stop_token_ids: Callable[[TokenizerLike], list[int]] | None = None
# Optional list of strings to stop generation, useful when stop tokens are
# not special tokens in the tokenizer
stop_str: list[str] | None = None
......
......@@ -22,8 +22,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import (
MistralTokenizer,
cached_tokenizer_from_config,
encode_tokens,
)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from contextlib import nullcontext
from typing import cast
......@@ -23,7 +24,7 @@ from vllm.multimodal.processing import (
replace_token_matches,
)
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from .utils import random_image
......@@ -238,7 +239,7 @@ def test_find_token_matches(
update_type,
):
# Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
prompt_updates = {
key: update_type(key, target, []).resolve(0)
......@@ -385,7 +386,7 @@ def test_find_text_matches(
update_type,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
prompt_updates = {
key: update_type(key, target, []).resolve(0)
......@@ -545,7 +546,7 @@ def test_find_update_text(
expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
for (
update_type,
......@@ -750,7 +751,7 @@ def test_find_update_tokens(
expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
for (
update_type,
......@@ -900,7 +901,7 @@ def test_find_mm_placeholders(
update_type,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
mm_prompt_updates = {
key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
......@@ -1029,7 +1030,7 @@ def test_hf_processor_init_kwargs(
expected_kwargs,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
......@@ -1065,7 +1066,7 @@ def test_hf_processor_call_kwargs(
expected_kwargs,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
......@@ -1088,9 +1089,7 @@ def test_apply_matches_no_match_exits_quickly():
With the fix, it should exit immediately when no match is found.
"""
import time
mock_tokenizer = cast(AnyTokenizer, object())
mock_tokenizer = cast(TokenizerLike, object())
# Create a long prompt with no placeholder
long_prompt = "x" * 10000
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment