"examples/cli/vllm_inc.py" did not exist on "c7bb1e83468122b2aee74e85778f8c8d84eb5a2a"
Unverified Commit 34a98427 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Refactor tokenizer interface (#29693)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent f223ed41
......@@ -5,7 +5,7 @@ import pytest
from tests.reasoning.utils import run_reasoning_extraction_mistral
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
parser_name = "mistral"
......
......@@ -4,7 +4,7 @@
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.reasoning import ReasoningParser
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
class StreamingReasoningReconstructor:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.transformers_utils.tokenizer import get_tokenizer
TOKENIZER_NAMES = ["BAAI/bge-base-en"]
@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
@pytest.mark.parametrize("n_tokens", [510])
def test_special_tokens(tokenizer_name: str, n_tokens: int):
    """Check the encoded length of a prompt made of repeated "[UNK]" markers.

    The prompt is "[UNK]" repeated ``n_tokens`` times; the encoded result is
    expected to contain exactly two ids more than that.
    """
    unk_prompt = "[UNK]" * n_tokens
    loaded_tokenizer = get_tokenizer(tokenizer_name, revision="main")
    encoded_ids = loaded_tokenizer.encode(unk_prompt)
    # NOTE(review): the two extra ids are presumably the special tokens the
    # tokenizer wraps around the prompt -- confirm against the model config.
    assert len(encoded_ids) == n_tokens + 2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
{meth}`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer
def test_get_llama3_eos_token():
    """Compare the EOS id from the tokenizer with the generation config's.

    For this checkpoint the tokenizer is expected to report a single EOS id,
    while the generation config is expected to list three.
    """
    llama3_repo = "meta-llama/Llama-3.2-1B-Instruct"

    assert get_tokenizer(llama3_repo).eos_token_id == 128009

    gen_cfg = try_get_generation_config(llama3_repo, trust_remote_code=False)
    assert gen_cfg is not None
    assert gen_cfg.eos_token_id == [128001, 128008, 128009]
def test_get_blip2_eos_token():
    """Compare the EOS id from the tokenizer with the generation config's.

    For this checkpoint the two sources are expected to disagree: the
    tokenizer reports 2, the generation config reports 50118.
    """
    blip2_repo = "Salesforce/blip2-opt-2.7b"

    assert get_tokenizer(blip2_repo).eos_token_id == 2

    gen_cfg = try_get_generation_config(blip2_repo, trust_remote_code=False)
    assert gen_cfg is not None
    assert gen_cfg.eos_token_id == 50118
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# NOTE: Since CI runs the tests from the `tests` directory, it is necessary to rename
# this module to avoid conflicting with HF's `tokenizers` package
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import _get_protocol_attrs # type: ignore
import pytest
from transformers import PreTrainedTokenizerBase
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizer import get_tokenizer
TOKENIZER_NAMES = [
"facebook/opt-125m",
"gpt2",
]
def _get_missing_attrs(obj: object, target: type):
    """Return the protocol attribute names of ``target`` that ``obj`` lacks.

    Args:
        obj: The object to inspect with ``hasattr``.
        target: The protocol class whose attributes are enumerated via
            ``typing._get_protocol_attrs``.

    Returns:
        A list of attribute names, in the iteration order of
        ``_get_protocol_attrs(target)``.
    """
    absent = []
    for attr_name in _get_protocol_attrs(target):
        if not hasattr(obj, attr_name):
            absent.append(attr_name)
    return absent
def test_tokenizer_like_protocol():
    """Check that several tokenizers expose every `TokenizerLike` attribute.

    The slow and fast "gpt2" tokenizers and the Mistral tokenizer are each
    loaded through `get_tokenizer` and compared against the protocol.

    This test is intentionally not parametrized: it accepts no
    ``tokenizer_name`` argument, so a ``parametrize`` decorator naming one
    would make pytest fail at collection time.
    """
    tokenizer_specs = [
        ("gpt2", {"use_fast": False}),
        ("gpt2", {"use_fast": True}),
        ("mistralai/Mistral-7B-Instruct-v0.3", {"tokenizer_mode": "mistral"}),
    ]

    # Load lazily inside the loop so a failure on one tokenizer is reported
    # before the next (potentially large) tokenizer is fetched.
    for tokenizer_name, load_kwargs in tokenizer_specs:
        tokenizer = get_tokenizer(tokenizer_name, **load_kwargs)
        missing_attrs = _get_missing_attrs(tokenizer, TokenizerLike)
        assert not missing_attrs, (
            f"Missing attrs for {tokenizer_name} ({load_kwargs}): {missing_attrs}"
        )
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
def test_tokenizer_revision(tokenizer_name: str):
# Assume that "main" branch always exists
tokenizer = get_tokenizer(tokenizer_name, revision="main")
......@@ -21,3 +47,13 @@ def test_tokenizer_revision(tokenizer_name: str):
# Assume that "never" branch always does not exist
with pytest.raises(OSError, match="not a valid git identifier"):
get_tokenizer(tokenizer_name, revision="never")
@pytest.mark.parametrize("tokenizer_name", ["BAAI/bge-base-en"])
@pytest.mark.parametrize("n_tokens", [510])
def test_special_tokens(tokenizer_name: str, n_tokens: int):
    """Check the encoded length of a prompt made of repeated "[UNK]" markers.

    The prompt is "[UNK]" repeated ``n_tokens`` times; the encoded result is
    expected to contain exactly two ids more than that.
    """
    unk_prompt = "[UNK]" * n_tokens
    loaded_tokenizer = get_tokenizer(tokenizer_name, revision="main")
    encoded_ids = loaded_tokenizer.encode(unk_prompt)
    # NOTE(review): the two extra ids are presumably the special tokens the
    # tokenizer wraps around the prompt -- confirm against the model config.
    assert len(encoded_ids) == n_tokens + 2
......@@ -6,7 +6,8 @@ from copy import deepcopy
import pytest
from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_cached_tokenizer
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
@pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"])
......@@ -25,7 +26,7 @@ def test_cached_tokenizer(model_id: str):
_check_consistency(unpickled_tokenizer, reference_tokenizer)
def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
def _check_consistency(target: TokenizerLike, expected: TokenizerLike):
assert isinstance(target, type(expected))
# Cached attributes
......
......@@ -8,7 +8,7 @@ import pytest
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (
FastIncrementalDetokenizer,
......
......@@ -7,7 +7,7 @@ import pytest
from mistral_common.exceptions import InvalidMessageStructureException
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.transformers_utils.tokenizers.mistral import (
from vllm.tokenizers.mistral import (
MistralTokenizer,
_prepare_apply_chat_template_tools_and_messages,
)
......@@ -308,25 +308,6 @@ class TestMistralTokenizer:
def test_get_added_vocab(self, mistral_tokenizer: MistralTokenizer):
assert mistral_tokenizer.get_added_vocab() == {}
def test_encode_one(self, mistral_tokenizer: MistralTokenizer):
    """Exercise `encode_one` with and without truncation.

    The expected ids depend on whether the fixture is a Tekken tokenizer.
    """
    text = "Hello world !"
    if mistral_tokenizer.is_tekken:
        expected_ids = [22177, 4304, 2662]
    else:
        expected_ids = [23325, 2294, 1686]

    assert mistral_tokenizer.encode_one(text) == expected_ids

    # With `truncation` left unset, `max_length` is expected to have no effect.
    assert mistral_tokenizer.encode_one(text, max_length=1) == expected_ids

    # With truncation enabled, only the first id is expected to survive.
    truncated = mistral_tokenizer.encode_one(text, truncation=True, max_length=1)
    assert truncated == expected_ids[:-2]

    # Explicitly disabling truncation behaves like leaving it unset.
    untruncated = mistral_tokenizer.encode_one(
        text, truncation=False, max_length=1
    )
    assert untruncated == expected_ids

    assert mistral_tokenizer.encode_one("") == []
def test_encode(self, mistral_tokenizer: MistralTokenizer):
token_ids = (
[1, 22177, 4304, 2662]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.tokenizers import TokenizerLike, TokenizerRegistry
from vllm.transformers_utils.tokenizer import get_tokenizer
class TestTokenizer(TokenizerLike):
    """Minimal `TokenizerLike` stand-in used to exercise the registry.

    Only the constructor hook and the BOS/EOS ids are provided; both ids are
    fixed constants.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
        """Build a fresh instance, ignoring every argument."""
        instance = TestTokenizer()  # type: ignore
        return instance

    @property
    def bos_token_id(self) -> int:
        """Fixed beginning-of-sequence id."""
        return 0

    @property
    def eos_token_id(self) -> int:
        """Fixed end-of-sequence id."""
        return 1
def test_customized_tokenizer():
    """Register `TestTokenizer` and resolve it through both lookup paths.

    The tokenizer is fetched first from `TokenizerRegistry` directly, then
    through `get_tokenizer` with ``tokenizer_mode="custom"``; each result
    must be a `TestTokenizer` exposing the expected BOS/EOS ids.
    """

    def _assert_is_test_tokenizer(candidate) -> None:
        # Shared checks applied to the result of each lookup path.
        assert isinstance(candidate, TestTokenizer)
        assert candidate.bos_token_id == 0
        assert candidate.eos_token_id == 1

    TokenizerRegistry.register(
        "test_tokenizer",
        __name__,
        TestTokenizer.__name__,
    )

    _assert_is_test_tokenizer(TokenizerRegistry.get_tokenizer("test_tokenizer"))
    _assert_is_test_tokenizer(
        get_tokenizer("test_tokenizer", tokenizer_mode="custom")
    )
......@@ -14,8 +14,9 @@ from vllm.entrypoints.openai.protocol import (
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.ernie45_tool_parser import Ernie45ToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
# Use a common model that is likely to be available
MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking"
......@@ -173,7 +174,7 @@ def test_extract_tool_calls(
def stream_delta_message_generator(
ernie45_tool_parser: Ernie45ToolParser,
ernie45_tokenizer: AnyTokenizer,
ernie45_tokenizer: TokenizerLike,
model_output: str,
request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:
......
......@@ -10,8 +10,9 @@ from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers.jamba_tool_parser import JambaToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
pytestmark = pytest.mark.cpu_test
......@@ -44,7 +45,9 @@ def assert_tool_calls(
def stream_delta_message_generator(
jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer, model_output: str
jamba_tool_parser: JambaToolParser,
jamba_tokenizer: TokenizerLike,
model_output: str,
) -> Generator[DeltaMessage, None, None]:
all_token_ids = jamba_tokenizer.encode(model_output, add_special_tokens=False)
......
......@@ -17,8 +17,9 @@ from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
Qwen3CoderToolParser,
)
from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
pytestmark = pytest.mark.cpu_test
......@@ -104,7 +105,7 @@ def assert_tool_calls(
def stream_delta_message_generator(
qwen3_tool_parser,
qwen3_tokenizer: AnyTokenizer,
qwen3_tokenizer: TokenizerLike,
model_output: str,
request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:
......
......@@ -15,8 +15,9 @@ from vllm.entrypoints.openai.protocol import (
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.seed_oss_tool_parser import SeedOssToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
pytestmark = pytest.mark.cpu_test
......@@ -256,7 +257,7 @@ def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
def stream_delta_message_generator(
seed_oss_tool_parser: SeedOssToolParser,
seed_oss_tokenizer: AnyTokenizer,
seed_oss_tokenizer: TokenizerLike,
model_output: str,
request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:
......
......@@ -13,8 +13,9 @@ from vllm.entrypoints.openai.protocol import (
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.xlam_tool_parser import xLAMToolParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
pytestmark = pytest.mark.cpu_test
......@@ -49,7 +50,7 @@ def assert_tool_calls(
def stream_delta_message_generator(
xlam_tool_parser: xLAMToolParser,
xlam_tokenizer: AnyTokenizer,
xlam_tokenizer: TokenizerLike,
model_output: str,
request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, call, patch
import pytest
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
@pytest.mark.parametrize(
    "allow_patterns,expected_relative_files",
    [
        (
            ["*.json", "correct*.txt"],
            ["json_file.json", "subfolder/correct.txt", "correct_2.txt"],
        ),
    ],
)
def test_list_filtered_repo_files(
    allow_patterns: list[str], expected_relative_files: list[str]
):
    """Check that `list_filtered_repo_files` keeps only the allowed files.

    A temporary directory tree stands in for a repository; `list_repo_files`
    is patched to return that tree's relative file paths.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Build the fake repository layout.
        root = Path(tmp_dir)
        nested = root / "subfolder"
        nested.mkdir()
        for fixture in (
            root / "json_file.json",
            root / "correct_2.txt",
            root / "uncorrect.txt",
            root / "uncorrect.jpeg",
            nested / "correct.txt",
            nested / "uncorrect_sub.txt",
        ):
            fixture.touch()

        # Relative paths of every regular file, as the patched listing result.
        repo_listing = [
            str(entry.relative_to(root))
            for entry in root.glob("**/*")
            if entry.is_file()
        ]

        # Replace the listing call used by the function under test.
        with patch(
            "vllm.transformers_utils.repo_utils.list_repo_files",
            MagicMock(return_value=repo_listing),
        ) as mock_list_repo_files:
            filtered = list_filtered_repo_files(
                tmp_dir, allow_patterns, "revision", "model", "token"
            )
            out_files = sorted(filtered)

        assert out_files == sorted(expected_relative_files)
        assert mock_list_repo_files.call_count == 1
        expected_call = call(
            repo_id=tmp_dir,
            revision="revision",
            repo_type="model",
            token="token",
        )
        assert mock_list_repo_files.call_args_list[0] == expected_call
def test_get_llama3_eos_token():
    """Compare the EOS id from the tokenizer with the generation config's.

    For this checkpoint the tokenizer is expected to report a single EOS id,
    while the generation config is expected to list three.
    """
    llama3_repo = "meta-llama/Llama-3.2-1B-Instruct"

    assert get_tokenizer(llama3_repo).eos_token_id == 128009

    gen_cfg = try_get_generation_config(llama3_repo, trust_remote_code=False)
    assert gen_cfg is not None
    assert gen_cfg.eos_token_id == [128001, 128008, 128009]
def test_get_blip2_eos_token():
    """Compare the EOS id from the tokenizer with the generation config's.

    For this checkpoint the two sources are expected to disagree: the
    tokenizer reports 2, the generation config reports 50118.
    """
    blip2_repo = "Salesforce/blip2-opt-2.7b"

    assert get_tokenizer(blip2_repo).eos_token_id == 2

    gen_cfg = try_get_generation_config(blip2_repo, trust_remote_code=False)
    assert gen_cfg is not None
    assert gen_cfg.eos_token_id == 50118
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, call, patch
import pytest
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
@pytest.mark.parametrize(
    "allow_patterns,expected_relative_files",
    [
        (
            ["*.json", "correct*.txt"],
            ["json_file.json", "subfolder/correct.txt", "correct_2.txt"],
        ),
    ],
)
def test_list_filtered_repo_files(
    allow_patterns: list[str], expected_relative_files: list[str]
):
    """Check that `list_filtered_repo_files` keeps only the allowed files.

    A temporary directory tree stands in for a repository; `list_repo_files`
    is patched to return that tree's relative file paths.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Build the fake repository layout.
        root = Path(tmp_dir)
        nested = root / "subfolder"
        nested.mkdir()
        for fixture in (
            root / "json_file.json",
            root / "correct_2.txt",
            root / "uncorrect.txt",
            root / "uncorrect.jpeg",
            nested / "correct.txt",
            nested / "uncorrect_sub.txt",
        ):
            fixture.touch()

        # Relative paths of every regular file, as the patched listing result.
        repo_listing = [
            str(entry.relative_to(root))
            for entry in root.glob("**/*")
            if entry.is_file()
        ]

        # Replace the listing call used by the function under test.
        with patch(
            "vllm.transformers_utils.repo_utils.list_repo_files",
            MagicMock(return_value=repo_listing),
        ) as mock_list_repo_files:
            filtered = list_filtered_repo_files(
                tmp_dir, allow_patterns, "revision", "model", "token"
            )
            out_files = sorted(filtered)

        assert out_files == sorted(expected_relative_files)
        assert mock_list_repo_files.call_count == 1
        expected_call = call(
            repo_id=tmp_dir,
            revision="revision",
            repo_type="model",
            token="token",
        )
        assert mock_list_repo_files.call_args_list[0] == expected_call
......@@ -18,7 +18,7 @@ from vllm.logprobs import PromptLogprobs, SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import (
EngineCoreEvent,
EngineCoreEventType,
......@@ -31,7 +31,7 @@ from vllm.v1.metrics.stats import IterationStats, SchedulerStats
def _ref_convert_id_to_token(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
token_id: int,
) -> str:
"""Reference impl of logprobs detokenization.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment