"benchmarks/profiler/vscode:/vscode.git/clone" did not exist on "6e568d45523fb2e5658b107233583b23931e7a2b"
Unverified commit 34a98427, authored by Cyrus Leung and committed by GitHub
Browse files

[Misc] Refactor tokenizer interface (#29693)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f223ed41
...@@ -316,7 +316,7 @@ steps: ...@@ -316,7 +316,7 @@ steps:
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/engine - tests/engine
- tests/tokenization - tests/tokenizers_
- tests/test_sequence - tests/test_sequence
- tests/test_config - tests/test_config
- tests/test_logger - tests/test_logger
...@@ -324,7 +324,7 @@ steps: ...@@ -324,7 +324,7 @@ steps:
commands: commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
# OOM in the CI unless we run this separately # OOM in the CI unless we run this separately
- pytest -v -s tokenization - pytest -v -s tokenizers_
- label: V1 Test e2e + engine # 30min - label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45 timeout_in_minutes: 45
......
...@@ -282,7 +282,7 @@ steps: ...@@ -282,7 +282,7 @@ steps:
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/engine - tests/engine
- tests/tokenization - tests/tokenizers_
- tests/test_sequence - tests/test_sequence
- tests/test_config - tests/test_config
- tests/test_logger - tests/test_logger
...@@ -290,7 +290,7 @@ steps: ...@@ -290,7 +290,7 @@ steps:
commands: commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
# OOM in the CI unless we run this separately # OOM in the CI unless we run this separately
- pytest -v -s tokenization - pytest -v -s tokenizers_
- label: V1 Test e2e + engine # 30min - label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45 timeout_in_minutes: 45
......
...@@ -620,7 +620,7 @@ def get_tokenizer( ...@@ -620,7 +620,7 @@ def get_tokenizer(
kwargs["use_fast"] = False kwargs["use_fast"] = False
if tokenizer_mode == "mistral": if tokenizer_mode == "mistral":
try: try:
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.tokenizers import MistralTokenizer
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"MistralTokenizer requires vllm package.\n" "MistralTokenizer requires vllm package.\n"
......
...@@ -216,14 +216,13 @@ You can add a new `ReasoningParser` similar to [vllm/reasoning/deepseek_r1_reaso ...@@ -216,14 +216,13 @@ You can add a new `ReasoningParser` similar to [vllm/reasoning/deepseek_r1_reaso
# import the required packages # import the required packages
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
DeltaMessage)
# define a reasoning parser and register it to vllm # define a reasoning parser and register it to vllm
# the name list in register_module can be used # the name list in register_module can be used
# in --reasoning-parser. # in --reasoning-parser.
class ExampleParser(ReasoningParser): class ExampleParser(ReasoningParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
def extract_reasoning_streaming( def extract_reasoning_streaming(
......
...@@ -422,7 +422,7 @@ Here is a summary of a plugin file: ...@@ -422,7 +422,7 @@ Here is a summary of a plugin file:
# in --tool-call-parser. you can define as many # in --tool-call-parser. you can define as many
# tool parsers as you want here. # tool parsers as you want here.
class ExampleToolParser(ToolParser): class ExampleToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
# adjust request. e.g.: set skip special tokens # adjust request. e.g.: set skip special tokens
......
...@@ -10,7 +10,7 @@ import pytest ...@@ -10,7 +10,7 @@ import pytest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.tokenizers import MistralTokenizer
@pytest.fixture() @pytest.fixture()
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def default_tokenizer() -> AnyTokenizer: def default_tokenizer() -> TokenizerLike:
return AutoTokenizer.from_pretrained("gpt2") return AutoTokenizer.from_pretrained("gpt2")
...@@ -7,7 +7,7 @@ import pytest ...@@ -7,7 +7,7 @@ import pytest
from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
...@@ -270,14 +270,14 @@ async def test_streaming_product_tool_call(): ...@@ -270,14 +270,14 @@ async def test_streaming_product_tool_call():
@pytest.fixture @pytest.fixture
def qwen_tokenizer() -> AnyTokenizer: def qwen_tokenizer() -> TokenizerLike:
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
return get_tokenizer("Qwen/Qwen3-32B") return get_tokenizer("Qwen/Qwen3-32B")
@pytest.fixture @pytest.fixture
def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser: def hermes_parser(qwen_tokenizer: TokenizerLike) -> Hermes2ProToolParser:
return Hermes2ProToolParser(qwen_tokenizer) return Hermes2ProToolParser(qwen_tokenizer)
...@@ -291,7 +291,7 @@ def any_chat_request() -> ChatCompletionRequest: ...@@ -291,7 +291,7 @@ def any_chat_request() -> ChatCompletionRequest:
def test_hermes_parser_streaming_just_forward_text( def test_hermes_parser_streaming_just_forward_text(
qwen_tokenizer: AnyTokenizer, qwen_tokenizer: TokenizerLike,
hermes_parser: Hermes2ProToolParser, hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest, any_chat_request: ChatCompletionRequest,
) -> None: ) -> None:
...@@ -323,7 +323,7 @@ def test_hermes_parser_streaming_just_forward_text( ...@@ -323,7 +323,7 @@ def test_hermes_parser_streaming_just_forward_text(
def test_hermes_parser_streaming_failure_case_bug_19056( def test_hermes_parser_streaming_failure_case_bug_19056(
qwen_tokenizer: AnyTokenizer, qwen_tokenizer: TokenizerLike,
hermes_parser: Hermes2ProToolParser, hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest, any_chat_request: ChatCompletionRequest,
) -> None: ) -> None:
...@@ -357,7 +357,7 @@ def test_hermes_parser_streaming_failure_case_bug_19056( ...@@ -357,7 +357,7 @@ def test_hermes_parser_streaming_failure_case_bug_19056(
def test_hermes_parser_streaming( def test_hermes_parser_streaming(
qwen_tokenizer: AnyTokenizer, qwen_tokenizer: TokenizerLike,
hermes_parser: Hermes2ProToolParser, hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest, any_chat_request: ChatCompletionRequest,
) -> None: ) -> None:
......
...@@ -7,11 +7,11 @@ import pytest ...@@ -7,11 +7,11 @@ import pytest
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
@pytest.fixture @pytest.fixture
def parser(default_tokenizer: AnyTokenizer): def parser(default_tokenizer: TokenizerLike):
return Llama3JsonToolParser(default_tokenizer) return Llama3JsonToolParser(default_tokenizer)
......
...@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import ( ...@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
) )
from vllm.entrypoints.openai.protocol import FunctionCall from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
# Test cases similar to pythonic parser but with Llama4 specific format # Test cases similar to pythonic parser but with Llama4 specific format
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]" SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
...@@ -64,7 +64,7 @@ PYTHON_TAG_FUNCTION_OUTPUT = ( ...@@ -64,7 +64,7 @@ PYTHON_TAG_FUNCTION_OUTPUT = (
@pytest.mark.parametrize("streaming", [True, False]) @pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer): def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer default_tokenizer
) )
...@@ -208,7 +208,7 @@ def test_tool_call( ...@@ -208,7 +208,7 @@ def test_tool_call(
streaming: bool, streaming: bool,
model_output: str, model_output: str,
expected_tool_calls: list[FunctionCall], expected_tool_calls: list[FunctionCall],
default_tokenizer: AnyTokenizer, default_tokenizer: TokenizerLike,
): ):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer default_tokenizer
...@@ -224,7 +224,7 @@ def test_tool_call( ...@@ -224,7 +224,7 @@ def test_tool_call(
assert actual.function == expected assert actual.function == expected
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer): def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer default_tokenizer
) )
...@@ -246,7 +246,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer): ...@@ -246,7 +246,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
@pytest.mark.parametrize("streaming", [False]) @pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer): def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
"""test regex timeout is handled gracefully""" """test regex timeout is handled gracefully"""
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
default_tokenizer default_tokenizer
......
...@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import ( ...@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
) )
from vllm.entrypoints.openai.protocol import FunctionCall from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1 # https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')" SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
...@@ -69,7 +69,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall( ...@@ -69,7 +69,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
@pytest.mark.parametrize("streaming", [True, False]) @pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer): def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer default_tokenizer
) )
...@@ -188,7 +188,7 @@ def test_tool_call( ...@@ -188,7 +188,7 @@ def test_tool_call(
streaming: bool, streaming: bool,
model_output: str, model_output: str,
expected_tool_calls: list[FunctionCall], expected_tool_calls: list[FunctionCall],
default_tokenizer: AnyTokenizer, default_tokenizer: TokenizerLike,
): ):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer default_tokenizer
...@@ -205,7 +205,7 @@ def test_tool_call( ...@@ -205,7 +205,7 @@ def test_tool_call(
assert actual.function == expected assert actual.function == expected
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer): def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer default_tokenizer
) )
...@@ -228,7 +228,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer): ...@@ -228,7 +228,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
@pytest.mark.parametrize("streaming", [False]) @pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer): def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
"""test regex timeout is handled gracefully""" """test regex timeout is handled gracefully"""
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
default_tokenizer default_tokenizer
......
...@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import ( ...@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
) )
from vllm.entrypoints.openai.protocol import FunctionCall from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1 # https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')" SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
...@@ -61,7 +61,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall( ...@@ -61,7 +61,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
@pytest.mark.parametrize("streaming", [True, False]) @pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer): def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer default_tokenizer
) )
...@@ -168,7 +168,7 @@ def test_tool_call( ...@@ -168,7 +168,7 @@ def test_tool_call(
streaming: bool, streaming: bool,
model_output: str, model_output: str,
expected_tool_calls: list[FunctionCall], expected_tool_calls: list[FunctionCall],
default_tokenizer: AnyTokenizer, default_tokenizer: TokenizerLike,
): ):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer default_tokenizer
...@@ -185,7 +185,7 @@ def test_tool_call( ...@@ -185,7 +185,7 @@ def test_tool_call(
assert actual.function == expected assert actual.function == expected
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer): def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer default_tokenizer
) )
...@@ -208,7 +208,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer): ...@@ -208,7 +208,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
@pytest.mark.parametrize("streaming", [False]) @pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer): def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
"""test regex timeout is handled gracefully""" """test regex timeout is handled gracefully"""
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
default_tokenizer default_tokenizer
......
...@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
class StreamingToolReconstructor: class StreamingToolReconstructor:
...@@ -111,7 +111,7 @@ def run_tool_extraction_nonstreaming( ...@@ -111,7 +111,7 @@ def run_tool_extraction_nonstreaming(
return tool_parser.extract_tool_calls(model_output, request) return tool_parser.extract_tool_calls(model_output, request)
def split_string_into_token_deltas(tokenizer: AnyTokenizer, text: str) -> list[str]: def split_string_into_token_deltas(tokenizer: TokenizerLike, text: str) -> list[str]:
# Split a string into a series of deltas using the provided tokenizer. Each # Split a string into a series of deltas using the provided tokenizer. Each
# delta will be the string equivalent of a single token. # delta will be the string equivalent of a single token.
token_ids = tokenizer.encode(text, add_special_tokens=False) token_ids = tokenizer.encode(text, add_special_tokens=False)
......
...@@ -28,8 +28,8 @@ from vllm.multimodal.utils import ( ...@@ -28,8 +28,8 @@ from vllm.multimodal.utils import (
encode_image_base64, encode_image_base64,
encode_video_base64, encode_video_base64,
) )
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from ..models.registry import HF_EXAMPLE_MODELS from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH from ..utils import VLLM_PATH
......
...@@ -10,7 +10,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( ...@@ -10,7 +10,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolParser, MistralToolParser,
) )
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.tokenizers import MistralTokenizer
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
......
...@@ -9,7 +9,7 @@ from mistral_common.audio import Audio ...@@ -9,7 +9,7 @@ from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.messages import UserMessage
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.tokenizers import MistralTokenizer
from ....conftest import AudioTestAssets from ....conftest import AudioTestAssets
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
......
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
from transformers.models.auto.auto_factory import _BaseAutoModelClass from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config.model import RunnerOption from vllm.config.model import RunnerOption
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from .....conftest import HfRunner, VllmRunner from .....conftest import HfRunner, VllmRunner
from ....registry import HF_EXAMPLE_MODELS from ....registry import HF_EXAMPLE_MODELS
...@@ -33,7 +33,7 @@ def run_test( ...@@ -33,7 +33,7 @@ def run_test(
auto_cls: type[_BaseAutoModelClass], auto_cls: type[_BaseAutoModelClass],
use_tokenizer_eos: bool, use_tokenizer_eos: bool,
comparator: Callable[..., None], comparator: Callable[..., None],
get_stop_token_ids: Callable[[AnyTokenizer], list[int]] | None, get_stop_token_ids: Callable[[TokenizerLike], list[int]] | None,
stop_str: list[str] | None, stop_str: list[str] | None,
limit_mm_per_prompt: dict[str, int], limit_mm_per_prompt: dict[str, int],
vllm_runner_kwargs: dict[str, Any] | None, vllm_runner_kwargs: dict[str, Any] | None,
......
...@@ -14,7 +14,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass ...@@ -14,7 +14,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.config.model import RunnerOption from vllm.config.model import RunnerOption
from vllm.logprobs import SampleLogprobs from vllm.logprobs import SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from .....conftest import ( from .....conftest import (
AUDIO_ASSETS, AUDIO_ASSETS,
...@@ -126,7 +126,7 @@ class VLMTestInfo(NamedTuple): ...@@ -126,7 +126,7 @@ class VLMTestInfo(NamedTuple):
vllm_runner_kwargs: dict[str, Any] | None = None vllm_runner_kwargs: dict[str, Any] | None = None
# Optional callable which gets a list of token IDs from the model tokenizer # Optional callable which gets a list of token IDs from the model tokenizer
get_stop_token_ids: Callable[[AnyTokenizer], list[int]] | None = None get_stop_token_ids: Callable[[TokenizerLike], list[int]] | None = None
# Optional list of strings to stop generation, useful when stop tokens are # Optional list of strings to stop generation, useful when stop tokens are
# not special tokens in the tokenizer # not special tokens in the tokenizer
stop_str: list[str] | None = None stop_str: list[str] | None = None
......
...@@ -22,8 +22,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict ...@@ -22,8 +22,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import ( from vllm.transformers_utils.tokenizer import (
MistralTokenizer,
cached_tokenizer_from_config, cached_tokenizer_from_config,
encode_tokens, encode_tokens,
) )
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from contextlib import nullcontext from contextlib import nullcontext
from typing import cast from typing import cast
...@@ -23,7 +24,7 @@ from vllm.multimodal.processing import ( ...@@ -23,7 +24,7 @@ from vllm.multimodal.processing import (
replace_token_matches, replace_token_matches,
) )
from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from .utils import random_image from .utils import random_image
...@@ -238,7 +239,7 @@ def test_find_token_matches( ...@@ -238,7 +239,7 @@ def test_find_token_matches(
update_type, update_type,
): ):
# Should not be used since there is nothing to convert to token IDs # Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
prompt_updates = { prompt_updates = {
key: update_type(key, target, []).resolve(0) key: update_type(key, target, []).resolve(0)
...@@ -385,7 +386,7 @@ def test_find_text_matches( ...@@ -385,7 +386,7 @@ def test_find_text_matches(
update_type, update_type,
): ):
# Should not be used since there is nothing to convert to text # Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
prompt_updates = { prompt_updates = {
key: update_type(key, target, []).resolve(0) key: update_type(key, target, []).resolve(0)
...@@ -545,7 +546,7 @@ def test_find_update_text( ...@@ -545,7 +546,7 @@ def test_find_update_text(
expected_by_update_type_mm_count, expected_by_update_type_mm_count,
): ):
# Should not be used since there is nothing to convert to text # Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
for ( for (
update_type, update_type,
...@@ -750,7 +751,7 @@ def test_find_update_tokens( ...@@ -750,7 +751,7 @@ def test_find_update_tokens(
expected_by_update_type_mm_count, expected_by_update_type_mm_count,
): ):
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
for ( for (
update_type, update_type,
...@@ -900,7 +901,7 @@ def test_find_mm_placeholders( ...@@ -900,7 +901,7 @@ def test_find_mm_placeholders(
update_type, update_type,
): ):
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
mm_prompt_updates = { mm_prompt_updates = {
key: [[update_type(key, [], repl).resolve(i)] for i in range(3)] key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
...@@ -1029,7 +1030,7 @@ def test_hf_processor_init_kwargs( ...@@ -1029,7 +1030,7 @@ def test_hf_processor_init_kwargs(
expected_kwargs, expected_kwargs,
): ):
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
ctx = InputProcessingContext( ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs), model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
...@@ -1065,7 +1066,7 @@ def test_hf_processor_call_kwargs( ...@@ -1065,7 +1066,7 @@ def test_hf_processor_call_kwargs(
expected_kwargs, expected_kwargs,
): ):
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(TokenizerLike, object())
ctx = InputProcessingContext( ctx = InputProcessingContext(
model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs), model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
...@@ -1088,9 +1089,7 @@ def test_apply_matches_no_match_exits_quickly(): ...@@ -1088,9 +1089,7 @@ def test_apply_matches_no_match_exits_quickly():
With the fix, it should exit immediately when no match is found. With the fix, it should exit immediately when no match is found.
""" """
import time mock_tokenizer = cast(TokenizerLike, object())
mock_tokenizer = cast(AnyTokenizer, object())
# Create a long prompt with no placeholder # Create a long prompt with no placeholder
long_prompt = "x" * 10000 long_prompt = "x" * 10000
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment