Commit dcb5624a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-dev

parents 55880ca2 ba41cc90
......@@ -68,6 +68,17 @@ def qwen2_vllm_to_hf_output(
return output_ids, hf_output_str, out_logprobs
def kimiv_vl_vllm_to_hf_output(
vllm_output: RunnerOutput,
model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]:
"""Sanitize vllm output [kimi_vl models] to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
hf_output_str = output_str + "<|im_end|>[EOS]"
return output_ids, hf_output_str, out_logprobs
def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput,
model: str) -> RunnerOutput:
config = AutoConfig.from_pretrained(model)
......
......@@ -59,24 +59,25 @@ def test_find_array(monkeypatch: pytest.MonkeyPatch):
def server_embedding():
# GritLM embedding implementation is only supported by XFormers backend.
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
with pytest.MonkeyPatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def server_generate():
args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
with pytest.MonkeyPatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client_embedding(monkeypatch: pytest.MonkeyPatch,
server_embedding: RemoteOpenAIServer):
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
async with server_embedding.get_async_client() as async_client:
yield async_client
async def client_embedding(server_embedding: RemoteOpenAIServer):
async with server_embedding.get_async_client() as async_client:
yield async_client
@pytest_asyncio.fixture
......
......@@ -153,14 +153,24 @@ def test_matryoshka(
with vllm_runner(model, task="embed", dtype=dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.encode(
example_prompts,
pooling_params=PoolingParams(dimensions=dimensions))
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
matryoshka_dimensions = (
vllm_model.model.llm_engine.model_config.matryoshka_dimensions)
assert matryoshka_dimensions is not None
if dimensions not in matryoshka_dimensions:
with pytest.raises(ValueError):
vllm_model.encode(
example_prompts,
pooling_params=PoolingParams(dimensions=dimensions))
else:
vllm_outputs = vllm_model.encode(
example_prompts,
pooling_params=PoolingParams(dimensions=dimensions))
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
# SPDX-License-Identifier: Apache-2.0
"""Compare the embedding outputs of HF and vLLM models.
Run `pytest tests/models/embedding/language/test_snowflake_arctic_embed.py`.
"""
import pytest
from tests.models.embedding.utils import EmbedModelInfo
from ..utils import check_embeddings_close
EMBEDDING_PROMPTS = [
'what is snowflake?', 'Where can I get the best tacos?', 'The Data Cloud!',
'Mexico City of Course!'
]
MODELS = [
EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
is_matryoshka=False,
architecture="BertModel",
enable_test=True),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-s",
is_matryoshka=False,
architecture="BertModel",
enable_test=False),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m",
is_matryoshka=False,
architecture="BertModel",
enable_test=False),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long",
is_matryoshka=False,
architecture="NomicBertModel",
enable_test=True),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l",
is_matryoshka=False,
architecture="BertModel",
enable_test=False),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
is_matryoshka=True,
architecture="BertModel",
enable_test=True),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0",
is_matryoshka=True,
architecture="XLMRobertaModel",
enable_test=True),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
is_matryoshka=True,
architecture="GteModel",
enable_test=True),
]
@pytest.mark.parametrize("model_info", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model_info: EmbedModelInfo,
dtype: str,
monkeypatch,
) -> None:
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.")
example_prompts = example_prompts + EMBEDDING_PROMPTS
vllm_extra_kwargs = {
"hf_overrides": {
"is_matryoshka": model_info.is_matryoshka
}
}
with hf_runner(model_info.name, dtype=dtype,
is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)
with vllm_runner(model_info.name,
task="embed",
dtype=dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
assert (vllm_model.model.llm_engine.model_config.is_matryoshka ==
model_info.is_matryoshka)
if model_info.architecture:
assert (model_info.architecture
in vllm_model.model.llm_engine.model_config.architectures)
vllm_outputs = vllm_model.encode(example_prompts)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from typing import NamedTuple, Optional
import torch
import torch.nn.functional as F
......@@ -37,3 +38,29 @@ def matryoshka_fy(tensor, dimensions):
tensor = tensor[..., :dimensions]
tensor = F.normalize(tensor, p=2, dim=1)
return tensor
class EmbedModelInfo(NamedTuple):
name: str
is_matryoshka: bool
matryoshka_dimensions: Optional[list[int]] = None
architecture: str = ""
enable_test: bool = True
def correctness_test(hf_model,
inputs,
vllm_outputs: Sequence[list[float]],
dimensions: Optional[int] = None):
hf_outputs = hf_model.encode(inputs)
if dimensions:
hf_outputs = matryoshka_fy(hf_outputs, dimensions)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
......@@ -15,12 +15,12 @@ from ...utils import check_logprobs_close
from ....utils import models_path_prefix
MODELS = [os.path.join(models_path_prefix, "microsoft/Florence-2-base")]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
TOKENIZER = os.path.join(models_path_prefix, "facebook/bart-base")
# Florence-2 model repo's tokenizer config is missing some special tokens.
# Therefore, we use a converted tokenizer from a forked repo
TOKENIZER = os.path.join(models_path_prefix, "Isotr0py/Florence-2-tokenizer")
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<CAPTION>", # special task token
"<OD>", # special task token which will output special tokens
"cherry_blossom":
"Describe in detail what is shown in the image.",
})
......@@ -47,7 +47,6 @@ def hf_to_vllm_output(hf_output: tuple[list[int], str,
output_ids, output_str, out_logprobs = hf_output
output_str = output_str.replace("</s>", "").replace("<s>", "")
output_ids = [ids for ids in output_ids if ids not in [0, 2]]
return output_ids, output_str, out_logprobs
......@@ -73,8 +72,11 @@ def run_test(
enforce_eager=True) as vllm_model:
vllm_outputs_per_case = [
vllm_model.generate_encoder_decoder_greedy_logprobs(
prompts, max_tokens, num_logprobs=num_logprobs)
for prompts in inputs
prompts,
max_tokens,
num_logprobs=num_logprobs,
skip_special_tokens=False,
) for prompts in inputs
]
hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs]
......@@ -95,6 +97,7 @@ def run_test(
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=1,
)
......
......@@ -254,10 +254,12 @@ def _test_processing_correctness_mistral(
"adept/fuyu-8b",
"google/gemma-3-4b-it",
"THUDM/glm-4v-9b",
"ibm-granite/granite-speech-3.3-8b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
"HuggingFaceM4/Idefics3-8B-Llama3",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
"moonshotai/Kimi-VL-A3B-Instruct",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-mistral-7b-hf",
......@@ -273,12 +275,14 @@ def _test_processing_correctness_mistral(
"nvidia/NVLM-D-72B",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
"microsoft/Phi-4-multimodal-instruct",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
"Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"Qwen/Qwen2.5-Omni-7B",
"Skywork/Skywork-R1V-38B",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
......
# SPDX-License-Identifier: Apache-2.0
"""Tests for phi4mm's multimodal preprocessing kwargs."""
import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets
from ...utils import build_model_context
@pytest.mark.parametrize("model_id", ["microsoft/Phi-4-multimodal-instruct"])
# yapf: disable
@pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img"),
[
({"dynamic_hd": 4}, 1329),
({"dynamic_hd": 16}, 4433),
# the default num_crops of phi-4-multimodal is 36
({}, 9585),
])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict[str, int],
expected_toks_per_img: int,
num_imgs: int,
kwargs_on_init: bool,
):
"""Ensure Phi4MMMultiModalProcessor handles dynamic_hd properly."""
# Avoid initializing CUDA early
from vllm.model_executor.models.phi4mm import _IMAGE_PLACEHOLDER_TOKEN_ID
ctx = build_model_context(
model_id,
mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
limit_mm_per_prompt={"image": num_imgs},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs
# Build the image str / prompt based on the number of images we pass
img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
image_size = ctx.get_hf_config(
).embd_layer["image_embd_layer"]["crop_size"]
dummy_image_size = (image_size * 7, image_size * 7)
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
# Ensure we have the right number of placeholders per num_crops size
img_tok_count = processed_inputs["prompt_token_ids"].count(
_IMAGE_PLACEHOLDER_TOKEN_ID)
assert img_tok_count == expected_toks_per_img * num_imgs
......@@ -121,9 +121,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
trust_remote_code=True),
"BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B"),
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"),
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m",
{"1b": "bigscience/bloomz-1b1"}),
"ChatGLMModel": _HfExamplesInfo("THUDM/chatglm3-6b",
trust_remote_code=True),
trust_remote_code=True,
max_transformers_version="4.48"),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("thu-coai/ShieldLM-6B-chatglm3", # noqa: E501
trust_remote_code=True),
"CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
......@@ -141,24 +143,26 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it",
min_transformers_version="4.50"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo(
"THUDM/GLM-4-32B-0414",
is_available_online=False,
min_transformers_version="4.52.dev0"
),
"GPT2LMHeadModel": _HfExamplesInfo("gpt2"),
"GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"),
"GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"),
"GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"),
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2",
{"alias": "gpt2"}),
"GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder",
{"tiny": "bigcode/tiny_starcoder_py"}), # noqa: E501
"GPTJForCausalLM": _HfExamplesInfo("Milos/slovak-gpt-j-405M",
{"6b": "EleutherAI/gpt-j-6b"}),
"GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m",
{"1b": "EleutherAI/pythia-1.4b"}),
"GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
"GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
"GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts", # noqa: E501
min_transformers_version="4.49"), # noqa: E501
"GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501
"Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1",
trust_remote_code=True),
"InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
......@@ -186,7 +190,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
trust_remote_code=True),
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", # noqa: E501
{"falcon3": "ehristoforu/Falcon3-MoE-2x7B-Insruct"}), # noqa: E501
"QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501
"MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
......@@ -194,7 +199,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
"Olmo2ForCausalLM": _HfExamplesInfo("shanearora/OLMo-7B-1124-hf"),
"OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
"OPTForCausalLM": _HfExamplesInfo("facebook/opt-iml-max-1.3b"),
"OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m",
{"1b": "facebook/opt-iml-max-1.3b"}),
"OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
trust_remote_code=True),
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
......@@ -204,10 +210,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
trust_remote_code=True),
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct",
extras={"2.5": "Qwen/Qwen2.5-7B-Instruct"}), # noqa: E501
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-0.5B-Instruct",
extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}), # noqa: E501
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-8B",
......@@ -233,8 +241,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
is_available_online=False,
trust_remote_code=True),
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct",
min_transformers_version="4.49"),
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"),
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
......@@ -245,11 +252,15 @@ _EMBEDDING_EXAMPLE_MODELS = {
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
trust_remote_code=True),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", # noqa: E501
trust_remote_code=True),
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
......@@ -273,6 +284,7 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501
}
_MULTIMODAL_EXAMPLE_MODELS = {
......@@ -286,10 +298,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible.", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it",
min_transformers_version="4.50"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-8b", # noqa: E501
min_transformers_version="4.52.0"), # noqa: E501
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
......@@ -302,6 +315,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
trust_remote_code=True),
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
min_transformers_version="4.51"),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
......@@ -322,7 +338,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501
trust_remote_code=True),
"Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501
min_transformers_version="4.50", # noqa: E501
extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
max_transformers_version="4.48",
......@@ -348,8 +363,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}), # noqa: E501
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
min_transformers_version="4.49"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct"), # noqa: E501
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B", # noqa: E501
min_transformers_version="4.52"), # noqa: E501
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
......@@ -358,7 +374,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
tokenizer="facebook/bart-base",
tokenizer="Isotr0py/Florence-2-tokenizer",
trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
......@@ -379,6 +395,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code=True,
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501
"Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
}
_TRANSFORMERS_MODELS = {
......
# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of a GPTQ model to a bitblas model.
Note: GPTQ and bitblas do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
bitblas/GPTQ models are in the top 3 selections of each other.
Note: bitblas internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for bitblas. As a result, we re-run the
test up to 3 times to see if we pass.
Run `pytest tests/models/test_bitblas.py`.
"""
from dataclasses import dataclass
import pytest
from .utils import check_logprobs_close
@dataclass
class ModelPair:
model_bitblas: str
model_gptq: str
model_pairs = [
ModelPair(model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas",
model_gptq="hxbgsyxh/opt-125m-4bit-128g"),
]
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
vllm_runner,
example_prompts,
model_pair: ModelPair,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with vllm_runner(model_pair.model_bitblas,
dtype=dtype,
quantization="bitblas") as bitblas_model:
bitblas_outputs = bitblas_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model_pair.model_gptq, dtype=dtype,
quantization="gptq") as gptq_model:
gptq_outputs = gptq_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=gptq_outputs,
outputs_1_lst=bitblas_outputs,
name_0="gptq",
name_1="bitblas",
)
# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of a GPTQ model to a bitblas model.
Note: GPTQ and bitblas do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
bitblas/GPTQ models are in the top 3 selections of each other.
Note: bitblas internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for bitblas. As a result, we re-run the
test up to 3 times to see if we pass.
Run `pytest tests/models/test_bitblas.py`.
"""
from dataclasses import dataclass
import pytest
from .utils import check_logprobs_close
@dataclass
class ModelPair:
model_gptq: str
model_pairs = [
ModelPair(model_gptq="hxbgsyxh/opt-125m-4bit-128g"),
]
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
vllm_runner,
example_prompts,
model_pair: ModelPair,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with vllm_runner(model_pair.model_gptq,
dtype=dtype,
quantization="bitblas") as bitblas_model:
bitblas_outputs = bitblas_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model_pair.model_gptq, dtype=dtype,
quantization="gptq") as gptq_model:
gptq_outputs = gptq_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=gptq_outputs,
outputs_1_lst=bitblas_outputs,
name_0="gptq",
name_1="gptq_bitblas",
)
......@@ -24,10 +24,7 @@ def test_can_initialize(model_arch):
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
hf_config.update(model_info.hf_overrides)
if hasattr(hf_config, "text_config"):
text_config: PretrainedConfig = hf_config.text_config
else:
text_config = hf_config
text_config = hf_config.get_text_config()
text_config.update({
"num_layers": 1,
......
......@@ -18,10 +18,9 @@ def test_plugin(
m.setenv("VLLM_USE_V1", "0")
m.setenv("VLLM_PLUGINS", "")
with pytest.raises(Exception) as excinfo:
match = "Cannot find model module"
with pytest.raises(ValueError, match=match):
LLM(model=dummy_opt_path, load_format="dummy")
error_msg = "has no vLLM implementation and the Transformers implementation is not compatible with vLLM" # noqa: E501
assert (error_msg in str(excinfo.value))
@create_new_process_for_each_test()
......
......@@ -266,16 +266,23 @@ def test_compressed_tensors_w8a8_dynamic_per_token(
reason="WNA16 is not supported on ROCm.")
@pytest.mark.parametrize(
"wNa16_args",
[
(os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w4a16-channel-v2"), "channel", None, 8),
(os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w4a16-group128-v2"), "group", 128, 8),
(os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w8a16-per-channel"), "channel", None, 4),
],
[(os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w4a16-channel-v2"), "channel", None, 8,
True, False),
(os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w4a16-group128-v2"), "group", 128, 8, True,
False),
(os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w8a16-per-channel"), "channel", None, 4,
True, False),
(os.path.join(models_path_prefix, "nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256"), "group", 128,
8, False, False),
(os.path.join(models_path_prefix, "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel"),
"channel", None, 8, False, False),
(os.path.join(models_path_prefix, "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder"),
"group", 128, 8, False, True)],
)
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="The tests are skipped on non-CUDA platform.")
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
model, strategy, group, pack_factor = wNa16_args
model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args
with vllm_runner(model) as llm:
def check_model(model):
......@@ -291,6 +298,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
if group is None else group)
assert qkv_proj.scheme.pack_factor == pack_factor
assert qkv_proj.scheme.symmetric == symmetric
assert qkv_proj.scheme.has_g_idx == has_g_idx
llm.apply_model(check_model)
......
......@@ -7,6 +7,9 @@ Run `pytest tests/samplers/test_beam_search.py`.
import pytest
import os
from ..utils import models_path_prefix
from transformers import AutoModelForSeq2SeqLM
from vllm.assets.audio import AudioAsset
@pytest.fixture(autouse=True)
......@@ -21,7 +24,8 @@ def v1(run_with_both_engines):
# 3. Use the model "huggyllama/llama-7b".
MAX_TOKENS = [64]
BEAM_WIDTHS = [4]
MODELS = [os.path.join(models_path_prefix, "TinyLlama/TinyLlama-1.1B-Chat-v1.0")]
MM_BEAM_WIDTHS = [2]
MODELS = [os.path.join(models_path_prefix, "TinyLlama/TinyLlama-1.1B-Chat-v1.0")]
@pytest.mark.skip_v1 # FIXME: This fails on V1 right now.
......@@ -50,15 +54,90 @@ def test_beam_search_single_input(
for i in range(len(example_prompts)):
hf_output_ids, hf_output_texts = hf_outputs[i]
vllm_output_ids, vllm_output_texts = vllm_outputs[i]
for i, (hf_text,
for j, (hf_text,
vllm_text) in enumerate(zip(hf_output_texts,
vllm_output_texts)):
print(f">>>{i}-th hf output:")
print(f">>>{j}-th hf output:")
print(hf_text)
print(f">>>{i}-th vllm output:")
print(f">>>{j}-th vllm output:")
print(vllm_text)
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
assert hf_output_ids[j] == vllm_output_ids[j], (
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
def test_beam_search_passes_multimodal_data(
hf_runner,
vllm_runner,
dtype: str,
max_tokens: int,
beam_width: int,
) -> None:
"""Ensure that beam search passes multimodal data through correctly."""
# NOTE - this test is primarily to check that mm data is passed to beams
# correctly. As such, we just need to check one extra modality to make
# sure things pass through properly.
audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
model = "Qwen/Qwen2-Audio-7B-Instruct"
audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>"
prompts = [
f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n" #noqa: E501
]
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
audio_token_id = hf_model.config.audio_token_index
eos_token_id = hf_model.tokenizer.eos_token_id # <|im_end|>
hf_outputs = hf_model.generate_beam_search(
prompts,
beam_width=beam_width,
max_tokens=max_tokens,
audios=audios,
)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(
prompts,
beam_width=beam_width,
max_tokens=max_tokens,
audios=audios,
)
seq_with_no_audio_toks = lambda seq: [
tok for tok in seq if tok != audio_token_id
]
for i in range(len(prompts)):
hf_output_ids, hf_output_texts = hf_outputs[i]
vllm_output_ids, vllm_output_texts = vllm_outputs[i]
for j, (hf_text,
vllm_text) in enumerate(zip(hf_output_texts,
vllm_output_texts)):
print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
print(hf_text)
print(f">>>{j}-th vllm output:")
print(vllm_text)
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
# Compare everything except for the audio tokens; we do this since
# the IDs returned from the transformers helper expands the audio
# token to match features, while the vLLM helper maintains the
# single audio token in the input text
filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j])
filtered_vllm_output_ids = seq_with_no_audio_toks(
vllm_output_ids[j])
# HF output IDs may contain the end of sequence
if len(filtered_hf_output_ids
) == len(filtered_vllm_output_ids) + 1:
assert filtered_hf_output_ids[-1] == eos_token_id
filtered_hf_output_ids = filtered_hf_output_ids[:-1]
assert filtered_hf_output_ids == filtered_vllm_output_ids
......@@ -64,9 +64,8 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
scorer_worker = create_worker(Worker, model_name, block_size,
num_gpu_blocks, seed)
scorer_worker.model_runner.disable_logprobs = True # accessed by mqa_scorer
scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
scorer_worker.model_runner.model.sampler.\
should_modify_greedy_probs_inplace = True
scorer_worker.model_runner.sampler.include_gpu_probs_tensor = True
scorer_worker.model_runner.sampler.should_modify_greedy_probs_inplace = True
vocab_size = scorer_worker.vocab_size
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import asdict
from dataclasses import MISSING, Field, asdict, dataclass, field
import pytest
import os
from vllm.config import ModelConfig, PoolerConfig
from vllm.config import ModelConfig, PoolerConfig, get_field
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
from utils import models_path_prefix
def test_get_field():
@dataclass
class TestConfig:
a: int
b: dict = field(default_factory=dict)
c: str = "default"
with pytest.raises(ValueError):
get_field(TestConfig, "a")
b = get_field(TestConfig, "b")
assert isinstance(b, Field)
assert b.default is MISSING
assert b.default_factory is dict
c = get_field(TestConfig, "c")
assert isinstance(c, Field)
assert c.default == "default"
assert c.default_factory is MISSING
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"),
[
......
......@@ -13,11 +13,11 @@ import torch
from vllm_test_utils.monitor import monitor
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
PlaceholderModule, StoreBoolean, bind_kv_cache,
deprecate_kwargs, get_open_port, memory_profiling,
merge_async_iterators, sha256, supports_kw,
swap_dict_values)
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, deprecate_kwargs, get_open_port,
memory_profiling, merge_async_iterators, sha256,
supports_kw, swap_dict_values)
from .utils import create_new_process_for_each_test, error_on_warning
from .utils import models_path_prefix
......@@ -419,6 +419,129 @@ def test_bind_kv_cache_pp():
assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
class TestLRUCache(LRUCache):
def _on_remove(self, key, value):
if not hasattr(self, "_remove_counter"):
self._remove_counter = 0
self._remove_counter += 1
def test_lru_cache():
cache = TestLRUCache(3)
assert cache.stat() == CacheInfo(hits=0, total=0)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
cache.put(1, 1)
assert len(cache) == 1
cache.put(1, 1)
assert len(cache) == 1
cache.put(2, 2)
assert len(cache) == 2
cache.put(3, 3)
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}
cache.put(4, 4)
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache.get(2) == 2
assert cache.stat() == CacheInfo(hits=1, total=1)
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
assert cache[2] == 2
assert cache.stat() == CacheInfo(hits=2, total=2)
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
cache.put(5, 5)
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2
assert cache.pop(5) == 5
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
assert cache.get(-1) is None
assert cache.stat() == CacheInfo(hits=2, total=3)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)
cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.get(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.put(6, 6)
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache
cache.remove_oldest()
assert len(cache) == 2
assert set(cache.cache) == {2, 6}
assert cache._remove_counter == 4
cache.clear()
assert len(cache) == 0
assert cache._remove_counter == 6
assert cache.stat() == CacheInfo(hits=0, total=0)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
cache._remove_counter = 0
cache[1] = 1
assert len(cache) == 1
cache[1] = 1
assert len(cache) == 1
cache[2] = 2
assert len(cache) == 2
cache[3] = 3
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}
cache[4] = 4
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache[2] == 2
cache[5] = 5
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2
del cache[5]
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache[6] = 6
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache
def test_placeholder_module_error_handling():
placeholder = PlaceholderModule("placeholder_1234")
......
# SPDX-License-Identifier: Apache-2.0
import pickle
from copy import deepcopy
import os
import pytest
from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from ..utils import models_path_prefix
def test_cached_tokenizer():
reference_tokenizer = AutoTokenizer.from_pretrained(os.path.join(models_path_prefix, "gpt2"))
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
get_cached_tokenizer)
@pytest.mark.parametrize("model_id", [os.path.join(models_path_prefix, "gpt2"), os.path.join(models_path_prefix, "THUDM/chatglm3-6b")])
def test_cached_tokenizer(model_id: str):
reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
trust_remote_code=True)
reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
reference_tokenizer.add_special_tokens(
{"additional_special_tokens": ["<SEP>"]})
cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
_check_consistency(cached_tokenizer, reference_tokenizer)
pickled_tokenizer = pickle.dumps(cached_tokenizer)
unpickled_tokenizer = pickle.loads(pickled_tokenizer)
_check_consistency(unpickled_tokenizer, reference_tokenizer)
def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
assert isinstance(target, type(expected))
# Cached attributes
assert target.all_special_ids == expected.all_special_ids
assert target.all_special_tokens == expected.all_special_tokens
assert (target.all_special_tokens_extended ==
expected.all_special_tokens_extended)
assert target.get_vocab() == expected.get_vocab()
assert len(target) == len(expected)
# Other attributes
assert getattr(target, "padding_side",
None) == getattr(expected, "padding_side", None)
assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
"prompt")
assert set(reference_tokenizer.all_special_ids) == set(
cached_tokenizer.all_special_ids)
assert set(reference_tokenizer.all_special_tokens) == set(
cached_tokenizer.all_special_tokens)
assert set(reference_tokenizer.all_special_tokens_extended) == set(
cached_tokenizer.all_special_tokens_extended)
assert target.encode("prompt") == expected.encode("prompt")
......@@ -3,19 +3,26 @@
from collections.abc import Generator
from typing import Any, Optional
import pytest
import os
from transformers import AutoTokenizer
import pytest
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from vllm.inputs import token_inputs
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.detokenizer import (Detokenizer,
detokenize_incrementally)
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
IncrementalDetokenizer,
SlowIncrementalDetokenizer)
from ..utils import models_path_prefix
SPECIAL_TOKS_TRUTH = [
"Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>", # noqa
]
TRUTH = [
"Hello here, this is a simple test",
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", # noqa
......@@ -25,7 +32,8 @@ TRUTH = [
# incomplete UTF-8 characters
# see https://github.com/vllm-project/vllm/pull/9625
"ပုံပြင်လေးပြောပြပါ်",
]
] + SPECIAL_TOKS_TRUTH
TOKENIZERS = [
os.path.join(models_path_prefix, "facebook/opt-125m"),
os.path.join(models_path_prefix, "gpt2"),
......@@ -41,26 +49,37 @@ TOKENIZERS = [
]
def _run_incremental_decode(tokenizer, all_input_ids,
skip_special_tokens: bool, starting_index: int):
decoded_text = ""
offset = 0
token_offset = 0
prev_tokens = None
for i in range(starting_index, len(all_input_ids)):
new_tokens, text, offset, token_offset = detokenize_incrementally(
tokenizer,
all_input_ids[:i + 1],
prev_tokens,
offset,
token_offset,
skip_special_tokens=skip_special_tokens)
decoded_text += text
if prev_tokens is None:
prev_tokens = new_tokens
else:
prev_tokens += new_tokens
return decoded_text
def _run_incremental_decode(tokenizer,
all_input_ids,
skip_special_tokens: bool,
starting_index: int,
spaces_between_special_tokens: bool = True,
fast: Optional[bool] = None):
prompt_token_ids = all_input_ids[:starting_index]
params = SamplingParams(
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
request = EngineCoreRequest("", prompt_token_ids, None, None, None, params,
None, 0.0, None)
if fast is None:
detokenizer = IncrementalDetokenizer.from_new_request(
tokenizer, request)
elif fast:
detokenizer = FastIncrementalDetokenizer(tokenizer, request)
else:
detokenizer = SlowIncrementalDetokenizer(tokenizer, request)
output_text = ""
for i, token_id in enumerate(all_input_ids[starting_index:]):
detokenizer.update([token_id], False)
finished = i == len(all_input_ids) - 1
output_text += detokenizer.get_next_output_text(finished, delta=True)
return output_text, detokenizer.output_token_ids
@pytest.fixture
......@@ -88,11 +107,13 @@ def test_mistral_edge_case(tokenizer, truth):
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
decoded_text = _run_incremental_decode(tokenizer,
all_input_ids,
skip_special_tokens=True,
starting_index=starting_index)
decoded_text, out_ids = _run_incremental_decode(
tokenizer,
all_input_ids,
skip_special_tokens=True,
starting_index=starting_index)
assert decoded_text == truth
assert out_ids == all_input_ids[starting_index:]
@pytest.fixture
......@@ -109,45 +130,91 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
spaces_between_special_tokens, fast):
if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
pytest.skip()
if skip_special_tokens and not spaces_between_special_tokens:
pytest.skip()
if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
# Fix up inconsistency in fast/slow tokenizer behaviour.
tokenizer.add_special_tokens({
"additional_special_tokens": [
at for at in
tokenizer._tokenizer.get_added_tokens_decoder().values()
if at.special
]
})
extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
else {"spaces_between_special_tokens": spaces_between_special_tokens}
truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
if tokenizer.bos_token_id is not None:
truth_tokens.insert(0, tokenizer.bos_token_id)
truth_tokens.append(tokenizer.eos_token_id)
new_truth = tokenizer.decode(truth_tokens,
skip_special_tokens=skip_special_tokens,
**extra_decode_args)
if with_prompt:
truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
prompt_input_ids = truth_tokens[:len(truth) // 2]
generated_input_ids = truth_tokens[len(truth) // 2:]
num_prompt_tokens = len(
tokenizer(truth[:len(truth) // 2],
add_special_tokens=False).input_ids)
if tokenizer.bos_token_id is not None:
num_prompt_tokens += 1
prompt_input_ids = truth_tokens[:num_prompt_tokens]
generated_input_ids = truth_tokens[num_prompt_tokens:]
all_input_ids = prompt_input_ids + generated_input_ids
starting_index = len(prompt_input_ids)
prompt = tokenizer.decode(prompt_input_ids,
skip_special_tokens=skip_special_tokens)
generated = truth[len(prompt):]
skip_special_tokens=skip_special_tokens,
**extra_decode_args)
generated = new_truth[len(prompt):]
else:
generated = truth
generated = new_truth
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
if skip_special_tokens:
if tokenizer.bos_token_id is not None:
all_input_ids = [tokenizer.bos_token_id] + all_input_ids
starting_index += 1
all_input_ids = all_input_ids + [tokenizer.eos_token_id]
all_input_ids = truth_tokens
decoded_text = _run_incremental_decode(
decoded_text, out_ids = _run_incremental_decode(
tokenizer,
all_input_ids,
skip_special_tokens=skip_special_tokens,
starting_index=starting_index)
starting_index=starting_index,
spaces_between_special_tokens=spaces_between_special_tokens,
fast=fast)
assert decoded_text == generated
assert out_ids == all_input_ids[starting_index:]
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("fast", (True, False))
def test_oov_decode(tokenizer, fast):
if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
pytest.skip()
decoded_text = _run_incremental_decode(
decoded_text, out_ids = _run_incremental_decode(
tokenizer, [len(tokenizer)],
skip_special_tokens=skip_special_tokens,
starting_index=starting_index)
skip_special_tokens=True,
starting_index=0,
spaces_between_special_tokens=True,
fast=fast)
assert decoded_text == ''
assert out_ids == [len(tokenizer)]
@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
init_kwargs = dict(
tokenizer_group = TokenizerGroup(
tokenizer_id=tokenizer_name,
enable_lora=False,
max_num_seqs=100,
......@@ -157,26 +224,20 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
revision=None,
)
tokenizer_group = get_tokenizer_group(
None,
**init_kwargs,
)
return Detokenizer(tokenizer_group)
@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
tokenizer) -> list[int]:
complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
return complete_sequence_token_ids
return tokenizer(complete_sequence, add_special_tokens=False).input_ids
def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or [1]
prompt_token_ids = prompt_token_ids or []
return Sequence(
seq_id=0,
inputs=token_inputs(prompt_token_ids, prompt="<s>"),
inputs=token_inputs(prompt_token_ids),
block_size=16,
)
......@@ -227,7 +288,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
assert sequential_result != "".join(sequential_logprobs_text_other_token)
if skip_special_tokens:
if not skip_special_tokens:
# Text for logprobs for the chosen token should be the same as the
# generated text. Note that this will only be true if we skip
# special tokens.
......@@ -236,10 +297,23 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
def test_decode_prompt_logprobs(complete_sequence: str,
complete_sequence_token_ids: list[int],
detokenizer: Detokenizer):
# We want to use skip_special_tokens=False here but Mistral tokenizers
# don't support that.
if complete_sequence not in SPECIAL_TOKS_TRUTH:
skip_special_tokens = True
elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
MistralTokenizer):
skip_special_tokens = False
else:
pytest.skip("MistralTokenizers don't support "
"skip_special_tokens=False")
return
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=True,
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
prompt_logprobs=1)
# Run sequentially.
......@@ -259,8 +333,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
# decoded_prompt_logprobs doesn't contain the first token.
token_ids = complete_sequence_token_ids
tokenizer = detokenizer.get_tokenizer_for_seq(seq)
text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
text_full = tokenizer.decode(token_ids,
skip_special_tokens=skip_special_tokens)
text_first = tokenizer.decode(token_ids[0],
skip_special_tokens=skip_special_tokens)
text = text_full[len(text_first):]
# Text for logprobs for the chosen token should be the same as the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment