Unverified Commit 6c47f6bf authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

[Core] Remove tokenizer group in vLLM (#24078)


Signed-off-by: default avatarZhuohan Li <zhuohan123@gmail.com>
parent c15309a7
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest import pytest
from transformers import PreTrainedTokenizer
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.inputs import token_inputs from vllm.inputs import token_inputs
...@@ -54,10 +51,7 @@ def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int, ...@@ -54,10 +51,7 @@ def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
- When the EOS token should be ignored, and the sequence continues - When the EOS token should be ignored, and the sequence continues
""" """
tokenizer = MagicMock(spec=PreTrainedTokenizer) stop_checker = StopChecker(max_model_len=1024)
get_tokenizer_for_seq = MagicMock(return_value=tokenizer)
stop_checker = StopChecker(max_model_len=1024,
get_tokenizer_for_seq=get_tokenizer_for_seq)
seq = sequence_with_eos( seq = sequence_with_eos(
text=text_wo_eos, text=text_wo_eos,
......
...@@ -58,16 +58,13 @@ def deepseek_r1_qwen_tokenizer(): ...@@ -58,16 +58,13 @@ def deepseek_r1_qwen_tokenizer():
@pytest.fixture @pytest.fixture
def stop_checker(): def stop_checker():
return StopChecker(max_model_len=10, return StopChecker(max_model_len=10)
get_tokenizer_for_seq=deepseek_r1_qwen_tokenizer)
@pytest.fixture @pytest.fixture
def stop_checker_with_reasoner(): def stop_checker_with_reasoner():
reasoner = MockReasoningParser(deepseek_r1_qwen_tokenizer) reasoner = MockReasoningParser(deepseek_r1_qwen_tokenizer)
return StopChecker(max_model_len=10, return StopChecker(max_model_len=10, reasoner=reasoner)
get_tokenizer_for_seq=deepseek_r1_qwen_tokenizer,
reasoner=reasoner)
def test_eos_token_stopping(stop_checker): def test_eos_token_stopping(stop_checker):
......
...@@ -208,25 +208,3 @@ def zephyr_lora_files(): ...@@ -208,25 +208,3 @@ def zephyr_lora_files():
"""Download zephyr LoRA files once per test session.""" """Download zephyr LoRA files once per test session."""
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora") return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
@pytest.fixture(scope="session")
def zephyr_lora_added_tokens_files(zephyr_lora_files):
"""Create zephyr LoRA files with added tokens once per test session."""
import shutil
from tempfile import TemporaryDirectory
from transformers import AutoTokenizer
tmp_dir = TemporaryDirectory()
tmp_model_dir = f"{tmp_dir.name}/zephyr"
shutil.copytree(zephyr_lora_files, tmp_model_dir)
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
# Copy tokenizer to adapter and add some unique tokens
# 32000, 32001, 32002
added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
special_tokens=True)
assert added == 3
tokenizer.save_pretrained(tmp_model_dir)
yield tmp_model_dir
tmp_dir.cleanup()
...@@ -29,11 +29,7 @@ def monkeypatch_module(): ...@@ -29,11 +29,7 @@ def monkeypatch_module():
@pytest.fixture(scope="module", params=[False, True]) @pytest.fixture(scope="module", params=[False, True])
def server( def server(request, monkeypatch_module, zephyr_lora_files): #noqa: F811
request,
monkeypatch_module,
zephyr_lora_files, #noqa: F811
zephyr_lora_added_tokens_files): # noqa: F811
use_v1 = request.param use_v1 = request.param
monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
...@@ -49,7 +45,6 @@ def server( ...@@ -49,7 +45,6 @@ def server(
"--enable-lora", "--enable-lora",
"--lora-modules", "--lora-modules",
f"zephyr-lora={zephyr_lora_files}", f"zephyr-lora={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank", "--max-lora-rank",
"64", "64",
"--max-cpu-loras", "--max-cpu-loras",
...@@ -79,7 +74,7 @@ async def client(server): ...@@ -79,7 +74,7 @@ async def client(server):
@pytest.mark.parametrize( @pytest.mark.parametrize(
# first test base model, then test loras # first test base model, then test loras
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
messages = [{ messages = [{
......
...@@ -27,7 +27,7 @@ GUIDED_DECODING_BACKENDS = ["outlines", "xgrammar", "guidance"] ...@@ -27,7 +27,7 @@ GUIDED_DECODING_BACKENDS = ["outlines", "xgrammar", "guidance"]
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files): def default_server_args(zephyr_lora_files):
return [ return [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
...@@ -41,7 +41,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files): ...@@ -41,7 +41,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files):
"--enable-lora", "--enable-lora",
"--lora-modules", "--lora-modules",
f"zephyr-lora={zephyr_lora_files}", f"zephyr-lora={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank", "--max-lora-rank",
"64", "64",
"--max-cpu-loras", "--max-cpu-loras",
...@@ -87,7 +86,7 @@ async def client(server): ...@@ -87,7 +86,7 @@ async def client(server):
@pytest.mark.parametrize( @pytest.mark.parametrize(
# first test base model, then test loras # first test base model, then test loras
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
completion = await client.completions.create(model=model_name, completion = await client.completions.create(model=model_name,
...@@ -115,20 +114,6 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): ...@@ -115,20 +114,6 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
assert completion.choices[0].prompt_logprobs is None assert completion.choices[0].prompt_logprobs is None
@pytest.mark.asyncio
async def test_added_lora_tokens(client: openai.AsyncOpenAI):
# test using token IDs
completion = await client.completions.create(
model="zephyr-lora2",
prompt=[0, 0, 32000, 32001, 32002],
echo=True,
max_tokens=5,
temperature=0.0,
)
# Added tokens should appear in tokenized prompt
assert completion.choices[0].text.startswith("<unk><unk>vllm1vllm2vllm3")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
# test using token IDs # test using token IDs
...@@ -147,7 +132,7 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): ...@@ -147,7 +132,7 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
@pytest.mark.parametrize( @pytest.mark.parametrize(
# first test base model, then test loras # first test base model, then test loras
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"], [MODEL_NAME, "zephyr-lora"],
) )
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs # test using token IDs
...@@ -713,7 +698,7 @@ async def test_guided_grammar(client: openai.AsyncOpenAI, ...@@ -713,7 +698,7 @@ async def test_guided_grammar(client: openai.AsyncOpenAI,
@pytest.mark.parametrize( @pytest.mark.parametrize(
# first test base model, then test loras # first test base model, then test loras
"model_name", "model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"], [MODEL_NAME, "zephyr-lora"],
) )
@pytest.mark.parametrize("logprobs_arg", [1, 0]) @pytest.mark.parametrize("logprobs_arg", [1, 0])
async def test_echo_logprob_completion(client: openai.AsyncOpenAI, async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
......
...@@ -21,10 +21,7 @@ CONFIG = AutoConfig.from_pretrained(MODEL_NAME) ...@@ -21,10 +21,7 @@ CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def default_server_args( def default_server_args() -> list[str]:
zephyr_lora_files,
zephyr_lora_added_tokens_files,
) -> list[str]:
return [ return [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
......
...@@ -67,12 +67,6 @@ def server_with_lora_modules_json(request, monkeypatch_module, ...@@ -67,12 +67,6 @@ def server_with_lora_modules_json(request, monkeypatch_module,
"base_model_name": MODEL_NAME "base_model_name": MODEL_NAME
} }
lora_module_2 = {
"name": "zephyr-lora2",
"path": zephyr_lora_files,
"base_model_name": MODEL_NAME
}
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
...@@ -84,7 +78,6 @@ def server_with_lora_modules_json(request, monkeypatch_module, ...@@ -84,7 +78,6 @@ def server_with_lora_modules_json(request, monkeypatch_module,
"--enable-lora", "--enable-lora",
"--lora-modules", "--lora-modules",
json.dumps(lora_module_1), json.dumps(lora_module_1),
json.dumps(lora_module_2),
"--max-lora-rank", "--max-lora-rank",
"64", "64",
"--max-cpu-loras", "--max-cpu-loras",
...@@ -121,7 +114,6 @@ async def test_static_lora_lineage(client: openai.AsyncOpenAI, ...@@ -121,7 +114,6 @@ async def test_static_lora_lineage(client: openai.AsyncOpenAI,
for lora_model in lora_models) for lora_model in lora_models)
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models) assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora" assert lora_models[0].id == "zephyr-lora"
assert lora_models[1].id == "zephyr-lora2"
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -209,7 +201,7 @@ async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path, ...@@ -209,7 +201,7 @@ async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path,
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path, async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path,
zephyr_lora_files): zephyr_lora_files):
"""Validate that many loras can be dynamically registered and inferenced """Validate that many loras can be dynamically registered and inferenced
with concurrently""" with concurrently"""
# This test file configures the server with --max-cpu-loras=2 and this test # This test file configures the server with --max-cpu-loras=2 and this test
......
...@@ -26,7 +26,6 @@ def server(zephyr_lora_files): ...@@ -26,7 +26,6 @@ def server(zephyr_lora_files):
"--enable-lora", "--enable-lora",
"--lora-modules", "--lora-modules",
f"zephyr-lora={zephyr_lora_files}", f"zephyr-lora={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_files}",
"--max-lora-rank", "--max-lora-rank",
"64", "64",
"--max-cpu-loras", "--max-cpu-loras",
...@@ -56,4 +55,3 @@ async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files): ...@@ -56,4 +55,3 @@ async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files):
assert all(lora_model.root == zephyr_lora_files assert all(lora_model.root == zephyr_lora_files
for lora_model in lora_models) for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora" assert lora_models[0].id == "zephyr-lora"
assert lora_models[1].id == "zephyr-lora2"
...@@ -14,7 +14,7 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" ...@@ -14,7 +14,7 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(zephyr_lora_added_tokens_files: str): # noqa: F811 def server():
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
...@@ -24,12 +24,6 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811 ...@@ -24,12 +24,6 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811
"--enforce-eager", "--enforce-eager",
"--max-num-seqs", "--max-num-seqs",
"128", "128",
# lora config
"--enable-lora",
"--lora-modules",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank",
"64",
"--enable-tokenizer-info-endpoint", "--enable-tokenizer-info-endpoint",
] ]
...@@ -38,10 +32,8 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811 ...@@ -38,10 +32,8 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def tokenizer_name(model_name: str, def tokenizer_name(model_name: str):
zephyr_lora_added_tokens_files: str): # noqa: F811 return model_name
return zephyr_lora_added_tokens_files if (
model_name == "zephyr-lora2") else model_name
@pytest_asyncio.fixture @pytest_asyncio.fixture
...@@ -53,7 +45,7 @@ async def client(server): ...@@ -53,7 +45,7 @@ async def client(server):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name,tokenizer_name", "model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], [(MODEL_NAME, MODEL_NAME)],
indirect=["tokenizer_name"], indirect=["tokenizer_name"],
) )
async def test_tokenize_completions( async def test_tokenize_completions(
...@@ -86,7 +78,7 @@ async def test_tokenize_completions( ...@@ -86,7 +78,7 @@ async def test_tokenize_completions(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name,tokenizer_name", "model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], [(MODEL_NAME, MODEL_NAME)],
indirect=["tokenizer_name"], indirect=["tokenizer_name"],
) )
async def test_tokenize_chat( async def test_tokenize_chat(
...@@ -148,7 +140,7 @@ async def test_tokenize_chat( ...@@ -148,7 +140,7 @@ async def test_tokenize_chat(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name,tokenizer_name", "model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], [(MODEL_NAME, MODEL_NAME)],
indirect=["tokenizer_name"], indirect=["tokenizer_name"],
) )
async def test_tokenize_chat_with_tools( async def test_tokenize_chat_with_tools(
...@@ -225,7 +217,7 @@ async def test_tokenize_chat_with_tools( ...@@ -225,7 +217,7 @@ async def test_tokenize_chat_with_tools(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name, tokenizer_name", "model_name, tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], [(MODEL_NAME, MODEL_NAME)],
indirect=["tokenizer_name"], indirect=["tokenizer_name"],
) )
async def test_tokenize_with_return_token_strs( async def test_tokenize_with_return_token_strs(
...@@ -260,7 +252,7 @@ async def test_tokenize_with_return_token_strs( ...@@ -260,7 +252,7 @@ async def test_tokenize_with_return_token_strs(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name,tokenizer_name", "model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], [(MODEL_NAME, MODEL_NAME)],
indirect=["tokenizer_name"], indirect=["tokenizer_name"],
) )
async def test_detokenize( async def test_detokenize(
...@@ -287,7 +279,7 @@ async def test_detokenize( ...@@ -287,7 +279,7 @@ async def test_detokenize(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name,tokenizer_name", "model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], [(MODEL_NAME, MODEL_NAME)],
indirect=["tokenizer_name"], indirect=["tokenizer_name"],
) )
async def test_tokenizer_info_basic( async def test_tokenizer_info_basic(
...@@ -384,4 +376,4 @@ async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer): ...@@ -384,4 +376,4 @@ async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer):
if chat_template: if chat_template:
assert isinstance(chat_template, assert isinstance(chat_template,
str), ("Chat template should be a string") str), ("Chat template should be a string")
assert chat_template.strip(), "Chat template should not be empty" assert chat_template.strip(), "Chat template should not be empty"
\ No newline at end of file
...@@ -18,6 +18,8 @@ SERVER_ARGS = [ ...@@ -18,6 +18,8 @@ SERVER_ARGS = [
"--enable-lora", "--enable-lora",
"--lora-modules", "--lora-modules",
f"{LORA_MODEL}={LORA_MODEL}", f"{LORA_MODEL}={LORA_MODEL}",
"--tokenizer",
f"{LORA_MODEL}",
] ]
TOOLS = [{ TOOLS = [{
......
...@@ -23,7 +23,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, ...@@ -23,7 +23,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64, from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
encode_video_base64) encode_video_base64)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from ..models.registry import HF_EXAMPLE_MODELS from ..models.registry import HF_EXAMPLE_MODELS
...@@ -69,12 +69,7 @@ def phi3v_model_config_mm_interleaved(): ...@@ -69,12 +69,7 @@ def phi3v_model_config_mm_interleaved():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def phi3v_tokenizer(): def phi3v_tokenizer():
return TokenizerGroup( return get_tokenizer(PHI3V_MODEL_ID)
tokenizer_id=PHI3V_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
...@@ -91,12 +86,7 @@ def qwen2_audio_model_config(): ...@@ -91,12 +86,7 @@ def qwen2_audio_model_config():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def qwen2_audio_tokenizer(): def qwen2_audio_tokenizer():
return TokenizerGroup( return get_tokenizer(QWEN2AUDIO_MODEL_ID)
tokenizer_id=QWEN2AUDIO_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
...@@ -115,12 +105,7 @@ def qwen25omni_model_config_mm_interleaved(): ...@@ -115,12 +105,7 @@ def qwen25omni_model_config_mm_interleaved():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def qwen25omni_tokenizer(): def qwen25omni_tokenizer():
return TokenizerGroup( return get_tokenizer(QWEN25OMNI_MODEL_ID)
tokenizer_id=QWEN25OMNI_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
...@@ -136,12 +121,7 @@ def mistral_model_config(): ...@@ -136,12 +121,7 @@ def mistral_model_config():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def mistral_tokenizer(): def mistral_tokenizer():
return TokenizerGroup( return get_tokenizer(MISTRAL_MODEL_ID)
tokenizer_id=MISTRAL_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
...@@ -2250,15 +2230,11 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ...@@ -2250,15 +2230,11 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
enforce_eager=model_info.enforce_eager, enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype) dtype=model_info.dtype)
# Build the tokenizer group and grab the underlying tokenizer # Build the tokenizer
tokenizer_group = TokenizerGroup( tokenizer = get_tokenizer(
model, model,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
) )
tokenizer = tokenizer_group.tokenizer
tools = ([{ tools = ([{
"type": "function", "type": "function",
...@@ -2307,14 +2283,10 @@ def test_resolve_content_format_hf_defined(model, expected_format): ...@@ -2307,14 +2283,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):
enforce_eager=model_info.enforce_eager, enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype) dtype=model_info.dtype)
tokenizer_group = TokenizerGroup( tokenizer = get_tokenizer(
model, model,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
) )
tokenizer = tokenizer_group.tokenizer
# Test detecting the tokenizer's chat_template # Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template( chat_template = resolve_hf_chat_template(
...@@ -2368,14 +2340,10 @@ def test_resolve_content_format_fallbacks(model, expected_format): ...@@ -2368,14 +2340,10 @@ def test_resolve_content_format_fallbacks(model, expected_format):
enforce_eager=model_info.enforce_eager, enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype) dtype=model_info.dtype)
tokenizer_group = TokenizerGroup( tokenizer = get_tokenizer(
model_config.tokenizer, model_config.tokenizer,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
) )
tokenizer = tokenizer_group.tokenizer
# Test detecting the tokenizer's chat_template # Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template( chat_template = resolve_hf_chat_template(
...@@ -2432,14 +2400,10 @@ def test_resolve_content_format_examples(template_path, expected_format): ...@@ -2432,14 +2400,10 @@ def test_resolve_content_format_examples(template_path, expected_format):
trust_remote_code=True, trust_remote_code=True,
) )
tokenizer_group = TokenizerGroup( dummy_tokenizer = get_tokenizer(
PHI3V_MODEL_ID, # Dummy PHI3V_MODEL_ID, # Dummy
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
) )
dummy_tokenizer = tokenizer_group.tokenizer
dummy_tokenizer.chat_template = None dummy_tokenizer.chat_template = None
chat_template = load_chat_template(EXAMPLES_DIR / template_path) chat_template = load_chat_template(EXAMPLES_DIR / template_path)
......
...@@ -13,14 +13,6 @@ from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test ...@@ -13,14 +13,6 @@ from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test
MODEL_PATH = "meta-llama/Llama-2-7b-hf" MODEL_PATH = "meta-llama/Llama-2-7b-hf"
EXPECTED_NO_LORA_OUTPUT = [
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501
]
EXPECTED_LORA_OUTPUT = [ EXPECTED_LORA_OUTPUT = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501 " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501
...@@ -79,23 +71,12 @@ def generate_and_test(llm, ...@@ -79,23 +71,12 @@ def generate_and_test(llm,
sql_lora_files, sql_lora_files,
tensorizer_config_dict: Union[dict, None] = None): tensorizer_config_dict: Union[dict, None] = None):
print("lora adapter created") print("lora adapter created")
assert do_sample(llm,
sql_lora_files,
tensorizer_config_dict=tensorizer_config_dict,
lora_id=0) == EXPECTED_NO_LORA_OUTPUT
print("lora 1") print("lora 1")
assert do_sample(llm, assert do_sample(llm,
sql_lora_files, sql_lora_files,
tensorizer_config_dict=tensorizer_config_dict, tensorizer_config_dict=tensorizer_config_dict,
lora_id=1) == EXPECTED_LORA_OUTPUT lora_id=1) == EXPECTED_LORA_OUTPUT
print("no lora")
assert do_sample(llm,
sql_lora_files,
tensorizer_config_dict=tensorizer_config_dict,
lora_id=0) == EXPECTED_NO_LORA_OUTPUT
print("lora 2") print("lora 2")
assert do_sample(llm, assert do_sample(llm,
sql_lora_files, sql_lora_files,
...@@ -110,6 +91,7 @@ def test_llama_lora(sql_lora_files): ...@@ -110,6 +91,7 @@ def test_llama_lora(sql_lora_files):
llm = vllm.LLM( llm = vllm.LLM(
MODEL_PATH, MODEL_PATH,
tokenizer=sql_lora_files,
enable_lora=True, enable_lora=True,
# also test odd max_num_seqs # also test odd max_num_seqs
max_num_seqs=13, max_num_seqs=13,
...@@ -123,6 +105,7 @@ def test_llama_lora_tp4(sql_lora_files): ...@@ -123,6 +105,7 @@ def test_llama_lora_tp4(sql_lora_files):
llm = vllm.LLM( llm = vllm.LLM(
MODEL_PATH, MODEL_PATH,
tokenizer=sql_lora_files,
enable_lora=True, enable_lora=True,
max_num_seqs=16, max_num_seqs=16,
max_loras=4, max_loras=4,
...@@ -137,6 +120,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): ...@@ -137,6 +120,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
llm = vllm.LLM( llm = vllm.LLM(
MODEL_PATH, MODEL_PATH,
tokenizer=sql_lora_files,
enable_lora=True, enable_lora=True,
max_num_seqs=16, max_num_seqs=16,
max_loras=4, max_loras=4,
...@@ -184,6 +168,7 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, ...@@ -184,6 +168,7 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri)) tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri))
loaded_llm = LLM(model=model_ref, loaded_llm = LLM(model=model_ref,
tokenizer=sql_lora_files,
load_format="tensorizer", load_format="tensorizer",
enable_lora=True, enable_lora=True,
enforce_eager=True, enforce_eager=True,
...@@ -195,11 +180,6 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, ...@@ -195,11 +180,6 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
tc_as_dict = tensorizer_config.to_serializable() tc_as_dict = tensorizer_config.to_serializable()
print("lora adapter created") print("lora adapter created")
assert do_sample(loaded_llm,
sql_lora_files,
tensorizer_config_dict=tc_as_dict,
lora_id=0) == EXPECTED_NO_LORA_OUTPUT
print("lora 1") print("lora 1")
assert do_sample(loaded_llm, assert do_sample(loaded_llm,
sql_lora_files, sql_lora_files,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.v1.engine.processor import Processor
def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id,
sql_lora_files):
"""
Test that we properly resolve the range of allowed token ids for lora
adapters that define additional tokens.
"""
# Set up a base model compatible with the sql_lora_files adapter and
# a known number of tokens in the base model.
model_config = ModelConfig(
model=llama_2_7b_base_huggingface_id,
tokenizer=llama_2_7b_base_huggingface_id,
tokenizer_mode="auto",
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
device_config=DeviceConfig(),
lora_config=LoRAConfig(),
)
tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
processor = Processor(vllm_config, tokenizer)
lora_request = LoRARequest("1", 1, str(sql_lora_files))
request_id = "1"
prompt = "a prompt"
# tokens added in the lora adapter should not raise an error
lora_token_ids = [32000, 32001, 32002, 32003]
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=lora_token_ids),
lora_request=lora_request)
# tokens in the base model should not raise an error
base_token_ids = [1000, 1001, 1002, 1003]
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=base_token_ids),
lora_request=lora_request)
# tokens not in the lora adapter should raise an error
invalid_token_ids = [35000, 35001, 35002, 35003]
with pytest.raises(ValueError):
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=invalid_token_ids),
lora_request=lora_request)
# tokens in the lora adapter with no lora request should raise an error
with pytest.raises(ValueError):
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=lora_token_ids),
)
def test_allowed_token_ids_with_lora_adapter_no_vocab(
qwen25vl_base_huggingface_id, qwen25vl_lora_files):
"""
Test that we properly resolve the range of allowed token ids for lora
adapters that do not define additional tokens.
"""
# Set up a base model compatible with the qwen25vl_lora_files adapter and
# a known number of tokens in the base model.
model_config = ModelConfig(
model=qwen25vl_base_huggingface_id,
tokenizer=qwen25vl_base_huggingface_id,
tokenizer_mode="auto",
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
device_config=DeviceConfig(),
lora_config=LoRAConfig(),
)
tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
processor = Processor(vllm_config, tokenizer)
lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files))
request_id = "1"
prompt = "a prompt"
# tokens in the base model should not raise an error
base_token_ids = [1000, 1001, 1002, 1003]
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=base_token_ids),
lora_request=lora_request)
# tokens in the base model with no lora request should not raise an error
base_token_ids = [1000, 1001, 1002, 1003]
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=base_token_ids),
)
# tokens not in the base model should raise an error
invalid_token_ids = [200000, 200001, 200002, 200003]
with pytest.raises(ValueError):
processor.process_inputs(
request_id,
prompt,
params=SamplingParams(allowed_token_ids=invalid_token_ids),
lora_request=lora_request)
...@@ -82,31 +82,20 @@ def test_quant_model_lora(tinyllama_lora_files, model): ...@@ -82,31 +82,20 @@ def test_quant_model_lora(tinyllama_lora_files, model):
gpu_memory_utilization=0.2, #avoid OOM gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization, quantization=model.quantization,
trust_remote_code=True, trust_remote_code=True,
enable_chunked_prefill=True) enable_chunked_prefill=True,
tokenizer=tinyllama_lora_files)
if model.quantization is None: if model.quantization is None:
expected_no_lora_output = [
"Here are some examples of orange-brown colors",
"I'm sorry, I don't have"
]
expected_lora_output = [ expected_lora_output = [
"#ff8050", "#ff8050",
"#ff8080", "#ff8080",
] ]
elif model.quantization == "awq": elif model.quantization == "awq":
expected_no_lora_output = [
"I'm sorry, I don't understand",
"I'm sorry, I don't understand",
]
expected_lora_output = [ expected_lora_output = [
"#f07700: A v", "#f07700: A v",
"#f00000: A v", "#f00000: A v",
] ]
elif model.quantization == "gptq": elif model.quantization == "gptq":
expected_no_lora_output = [
"I'm sorry, I don't have",
"I'm sorry, I don't have",
]
expected_lora_output = [ expected_lora_output = [
"#f08800: This is", "#f08800: This is",
"#f07788 \n#", "#f07788 \n#",
...@@ -117,7 +106,6 @@ def test_quant_model_lora(tinyllama_lora_files, model): ...@@ -117,7 +106,6 @@ def test_quant_model_lora(tinyllama_lora_files, model):
# Assert that the outputs changed. # Assert that the outputs changed.
if (model.quantization == "gptq" if (model.quantization == "gptq"
and expected_output is expected_lora_output): and expected_output is expected_lora_output):
assert output != expected_no_lora_output
for i, o in enumerate(output): for i, o in enumerate(output):
assert o.startswith( assert o.startswith(
'#'), f"Expected example {i} to start with # but got {o}" '#'), f"Expected example {i} to start with # but got {o}"
...@@ -127,12 +115,6 @@ def test_quant_model_lora(tinyllama_lora_files, model): ...@@ -127,12 +115,6 @@ def test_quant_model_lora(tinyllama_lora_files, model):
max_tokens = 10 max_tokens = 10
print("lora adapter created") print("lora adapter created")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=0,
max_tokens=max_tokens)
expect_match(output, expected_no_lora_output)
print("lora 1") print("lora 1")
output = do_sample(llm, output = do_sample(llm,
tinyllama_lora_files, tinyllama_lora_files,
...@@ -140,13 +122,6 @@ def test_quant_model_lora(tinyllama_lora_files, model): ...@@ -140,13 +122,6 @@ def test_quant_model_lora(tinyllama_lora_files, model):
max_tokens=max_tokens) max_tokens=max_tokens)
expect_match(output, expected_lora_output) expect_match(output, expected_lora_output)
print("no lora")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=0,
max_tokens=max_tokens)
expect_match(output, expected_no_lora_output)
print("lora 2") print("lora 2")
output = do_sample(llm, output = do_sample(llm,
tinyllama_lora_files, tinyllama_lora_files,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import get_lora_tokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
tokenizer_group = TokenizerGroup(
tokenizer_id="gpt2",
enable_lora=True,
max_num_seqs=1,
max_loras=1,
max_input_length=None,
)
lora_request = LoRARequest("1", 1, sql_lora_files)
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
prompt="prompt", lora_request=lora_request)
assert reference_tokenizer.encode(
"prompt") == await tokenizer_group.encode_async(
prompt="prompt", lora_request=lora_request)
assert isinstance(tokenizer_group.get_lora_tokenizer(None),
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
None) == await tokenizer_group.get_lora_tokenizer_async(None)
assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request),
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
lora_request) != tokenizer_group.get_lora_tokenizer(None)
assert tokenizer_group.get_lora_tokenizer(
lora_request) == await tokenizer_group.get_lora_tokenizer_async(
lora_request)
def test_get_lora_tokenizer(sql_lora_files, tmp_path):
lora_request = None
tokenizer = get_lora_tokenizer(lora_request)
assert not tokenizer
lora_request = LoRARequest("1", 1, sql_lora_files)
tokenizer = get_lora_tokenizer(lora_request)
assert tokenizer.get_added_vocab()
lora_request = LoRARequest("1", 1, str(tmp_path))
tokenizer = get_lora_tokenizer(lora_request)
assert not tokenizer
@pytest.mark.parametrize("enable_lora", [True, False])
@pytest.mark.parametrize("max_num_seqs", [1, 2])
@pytest.mark.parametrize("max_loras", [1, 2])
def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras):
tokenizer_group = TokenizerGroup(
tokenizer_id="gpt2",
enable_lora=enable_lora,
max_num_seqs=max_num_seqs,
max_loras=max_loras,
max_input_length=None,
)
if enable_lora:
assert tokenizer_group.lora_tokenizers.capacity == max(
max_num_seqs, max_loras)
else:
assert tokenizer_group.lora_tokenizers.capacity == 0
...@@ -11,7 +11,7 @@ import pytest ...@@ -11,7 +11,7 @@ import pytest
from vllm.inputs import token_inputs from vllm.inputs import token_inputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import Sequence from vllm.sequence import Sequence
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer import get_tokenizer
# Make two prefixes with different first blocks. # Make two prefixes with different first blocks.
prefix_start = [("You are an expert"), ("You are a")] prefix_start = [("You are an expert"), ("You are a")]
...@@ -47,12 +47,7 @@ def flatten_2d(li): ...@@ -47,12 +47,7 @@ def flatten_2d(li):
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
concurrent_lora_int_ids: list[Optional[int]]): concurrent_lora_int_ids: list[Optional[int]]):
tokenizer = TokenizerGroup( tokenizer = get_tokenizer("facebook/opt-125m")
tokenizer_id="facebook/opt-125m",
enable_lora=False,
max_num_seqs=max_num_seqs,
max_input_length=None,
)
hashes: list[list[list[int]]] = [] hashes: list[list[list[int]]] = []
...@@ -76,7 +71,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, ...@@ -76,7 +71,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
inputs=token_inputs(prompt_token_ids, inputs=token_inputs(prompt_token_ids,
prompt=prompt), prompt=prompt),
block_size=block_size, block_size=block_size,
eos_token_id=tokenizer.tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
lora_request=lora_request) lora_request=lora_request)
num_blocks = len(prompt_token_ids) // block_size num_blocks = len(prompt_token_ids) // block_size
......
...@@ -11,7 +11,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer, ...@@ -11,7 +11,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from vllm.inputs import token_inputs from vllm.inputs import token_inputs
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer, from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
...@@ -221,17 +221,14 @@ def test_oov_decode(tokenizer, fast): ...@@ -221,17 +221,14 @@ def test_oov_decode(tokenizer, fast):
@pytest.fixture @pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer: def detokenizer(tokenizer_name: str) -> Detokenizer:
tokenizer_group = TokenizerGroup( tokenizer = get_tokenizer(
tokenizer_id=tokenizer_name, tokenizer_name,
enable_lora=False,
max_num_seqs=100,
max_input_length=None,
tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto", tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto",
trust_remote_code=False, trust_remote_code=False,
revision=None, revision=None,
) )
return Detokenizer(tokenizer_group) return Detokenizer(tokenizer)
@pytest.fixture(name="complete_sequence_token_ids") @pytest.fixture(name="complete_sequence_token_ids")
...@@ -312,8 +309,7 @@ def test_decode_prompt_logprobs(complete_sequence: str, ...@@ -312,8 +309,7 @@ def test_decode_prompt_logprobs(complete_sequence: str,
# don't support that. # don't support that.
if complete_sequence not in SPECIAL_TOKS_TRUTH: if complete_sequence not in SPECIAL_TOKS_TRUTH:
skip_special_tokens = True skip_special_tokens = True
elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None), elif not isinstance(detokenizer.tokenizer, MistralTokenizer):
MistralTokenizer):
skip_special_tokens = False skip_special_tokens = False
else: else:
pytest.skip("MistralTokenizers don't support " pytest.skip("MistralTokenizers don't support "
...@@ -339,7 +335,7 @@ def test_decode_prompt_logprobs(complete_sequence: str, ...@@ -339,7 +335,7 @@ def test_decode_prompt_logprobs(complete_sequence: str,
# decoded_prompt_logprobs doesn't contain the first token. # decoded_prompt_logprobs doesn't contain the first token.
token_ids = complete_sequence_token_ids token_ids = complete_sequence_token_ids
tokenizer = detokenizer.get_tokenizer_for_seq(seq) tokenizer = detokenizer.tokenizer
text_full = tokenizer.decode(token_ids, text_full = tokenizer.decode(token_ids,
skip_special_tokens=skip_special_tokens) skip_special_tokens=skip_special_tokens)
text_first = tokenizer.decode(token_ids[0], text_first = tokenizer.decode(token_ids[0],
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@pytest.mark.asyncio
async def test_tokenizer_group():
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer_group = TokenizerGroup(
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None,
)
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
prompt="prompt", lora_request=None)
assert reference_tokenizer.encode(
"prompt") == await tokenizer_group.encode_async(prompt="prompt",
lora_request=None)
assert isinstance(tokenizer_group.get_lora_tokenizer(None),
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
None) == await tokenizer_group.get_lora_tokenizer_async(None)
...@@ -57,6 +57,10 @@ class TestTokenizer(TokenizerBase): ...@@ -57,6 +57,10 @@ class TestTokenizer(TokenizerBase):
def max_token_id(self) -> int: def max_token_id(self) -> int:
raise NotImplementedError() raise NotImplementedError()
@property
def truncation_side(self) -> str:
raise NotImplementedError()
def __call__( def __call__(
self, self,
text: Union[str, list[str], list[int]], text: Union[str, list[str], list[int]],
......
...@@ -12,7 +12,6 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, ...@@ -12,7 +12,6 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
generate_dummy_prompt_logprobs_tensors, generate_dummy_prompt_logprobs_tensors,
generate_dummy_sample_logprobs) generate_dummy_sample_logprobs)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from ...distributed.conftest import publisher_config, random_port # noqa: F401 from ...distributed.conftest import publisher_config, random_port # noqa: F401
...@@ -24,7 +23,7 @@ EngineCorePromptLogprobsType = tuple[torch.Tensor, torch.Tensor] ...@@ -24,7 +23,7 @@ EngineCorePromptLogprobsType = tuple[torch.Tensor, torch.Tensor]
def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
"""Generate output processor dummy test vectors, without logprobs """Generate output processor dummy test vectors, without logprobs
Returns: Returns:
DummyOutputProcessorTestVectors instance with no logprobs DummyOutputProcessorTestVectors instance with no logprobs
""" """
...@@ -48,9 +47,6 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: ...@@ -48,9 +47,6 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
] ]
return DummyOutputProcessorTestVectors( return DummyOutputProcessorTestVectors(
tokenizer=tokenizer, tokenizer=tokenizer,
tokenizer_group=init_tokenizer_from_configs(
vllm_config.model_config, vllm_config.scheduler_config,
vllm_config.lora_config),
vllm_config=vllm_config, vllm_config=vllm_config,
full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS], full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS],
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
...@@ -68,7 +64,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: ...@@ -68,7 +64,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
@pytest.fixture @pytest.fixture
def dummy_test_vectors() -> DummyOutputProcessorTestVectors: def dummy_test_vectors() -> DummyOutputProcessorTestVectors:
"""Generate output processor dummy test vectors, with logprobs """Generate output processor dummy test vectors, with logprobs
Returns: Returns:
DummyOutputProcessorTestVectors instance with logprobs DummyOutputProcessorTestVectors instance with logprobs
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment