Commit dcb5624a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-dev

parents 55880ca2 ba41cc90
...@@ -68,6 +68,17 @@ def qwen2_vllm_to_hf_output( ...@@ -68,6 +68,17 @@ def qwen2_vllm_to_hf_output(
return output_ids, hf_output_str, out_logprobs return output_ids, hf_output_str, out_logprobs
def kimiv_vl_vllm_to_hf_output(
vllm_output: RunnerOutput,
model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]:
"""Sanitize vllm output [kimi_vl models] to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
hf_output_str = output_str + "<|im_end|>[EOS]"
return output_ids, hf_output_str, out_logprobs
def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput, def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput,
model: str) -> RunnerOutput: model: str) -> RunnerOutput:
config = AutoConfig.from_pretrained(model) config = AutoConfig.from_pretrained(model)
......
...@@ -59,24 +59,25 @@ def test_find_array(monkeypatch: pytest.MonkeyPatch): ...@@ -59,24 +59,25 @@ def test_find_array(monkeypatch: pytest.MonkeyPatch):
def server_embedding(): def server_embedding():
# GritLM embedding implementation is only supported by XFormers backend. # GritLM embedding implementation is only supported by XFormers backend.
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)] args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with pytest.MonkeyPatch.context() as m:
yield remote_server m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server_generate(): def server_generate():
args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)] args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with pytest.MonkeyPatch.context() as m:
yield remote_server m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def client_embedding(monkeypatch: pytest.MonkeyPatch, async def client_embedding(server_embedding: RemoteOpenAIServer):
server_embedding: RemoteOpenAIServer): async with server_embedding.get_async_client() as async_client:
with monkeypatch.context() as m: yield async_client
m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
async with server_embedding.get_async_client() as async_client:
yield async_client
@pytest_asyncio.fixture @pytest_asyncio.fixture
......
...@@ -153,14 +153,24 @@ def test_matryoshka( ...@@ -153,14 +153,24 @@ def test_matryoshka(
with vllm_runner(model, task="embed", dtype=dtype, with vllm_runner(model, task="embed", dtype=dtype,
max_model_len=None) as vllm_model: max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.encode( matryoshka_dimensions = (
example_prompts, vllm_model.model.llm_engine.model_config.matryoshka_dimensions)
pooling_params=PoolingParams(dimensions=dimensions)) assert matryoshka_dimensions is not None
check_embeddings_close( if dimensions not in matryoshka_dimensions:
embeddings_0_lst=hf_outputs, with pytest.raises(ValueError):
embeddings_1_lst=vllm_outputs, vllm_model.encode(
name_0="hf", example_prompts,
name_1="vllm", pooling_params=PoolingParams(dimensions=dimensions))
tol=1e-2, else:
) vllm_outputs = vllm_model.encode(
example_prompts,
pooling_params=PoolingParams(dimensions=dimensions))
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
# SPDX-License-Identifier: Apache-2.0
"""Compare the embedding outputs of HF and vLLM models.
Run `pytest tests/models/embedding/language/test_snowflake_arctic_embed.py`.
"""
import pytest
from tests.models.embedding.utils import EmbedModelInfo
from ..utils import check_embeddings_close
EMBEDDING_PROMPTS = [
'what is snowflake?', 'Where can I get the best tacos?', 'The Data Cloud!',
'Mexico City of Course!'
]
MODELS = [
EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
is_matryoshka=False,
architecture="BertModel",
enable_test=True),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-s",
is_matryoshka=False,
architecture="BertModel",
enable_test=False),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m",
is_matryoshka=False,
architecture="BertModel",
enable_test=False),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long",
is_matryoshka=False,
architecture="NomicBertModel",
enable_test=True),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l",
is_matryoshka=False,
architecture="BertModel",
enable_test=False),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
is_matryoshka=True,
architecture="BertModel",
enable_test=True),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0",
is_matryoshka=True,
architecture="XLMRobertaModel",
enable_test=True),
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
is_matryoshka=True,
architecture="GteModel",
enable_test=True),
]
@pytest.mark.parametrize("model_info", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model_info: EmbedModelInfo,
dtype: str,
monkeypatch,
) -> None:
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.")
example_prompts = example_prompts + EMBEDDING_PROMPTS
vllm_extra_kwargs = {
"hf_overrides": {
"is_matryoshka": model_info.is_matryoshka
}
}
with hf_runner(model_info.name, dtype=dtype,
is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)
with vllm_runner(model_info.name,
task="embed",
dtype=dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
assert (vllm_model.model.llm_engine.model_config.is_matryoshka ==
model_info.is_matryoshka)
if model_info.architecture:
assert (model_info.architecture
in vllm_model.model.llm_engine.model_config.architectures)
vllm_outputs = vllm_model.encode(example_prompts)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence from collections.abc import Sequence
from typing import NamedTuple, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -37,3 +38,29 @@ def matryoshka_fy(tensor, dimensions): ...@@ -37,3 +38,29 @@ def matryoshka_fy(tensor, dimensions):
tensor = tensor[..., :dimensions] tensor = tensor[..., :dimensions]
tensor = F.normalize(tensor, p=2, dim=1) tensor = F.normalize(tensor, p=2, dim=1)
return tensor return tensor
class EmbedModelInfo(NamedTuple):
name: str
is_matryoshka: bool
matryoshka_dimensions: Optional[list[int]] = None
architecture: str = ""
enable_test: bool = True
def correctness_test(hf_model,
inputs,
vllm_outputs: Sequence[list[float]],
dimensions: Optional[int] = None):
hf_outputs = hf_model.encode(inputs)
if dimensions:
hf_outputs = matryoshka_fy(hf_outputs, dimensions)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
...@@ -15,12 +15,12 @@ from ...utils import check_logprobs_close ...@@ -15,12 +15,12 @@ from ...utils import check_logprobs_close
from ....utils import models_path_prefix from ....utils import models_path_prefix
MODELS = [os.path.join(models_path_prefix, "microsoft/Florence-2-base")] MODELS = [os.path.join(models_path_prefix, "microsoft/Florence-2-base")]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Florence-2 model repo's tokenizer config is missing some special tokens.
# Therefore, we borrow the BartTokenizer from the original Bart model # Therefore, we use a converted tokenizer from a forked repo
TOKENIZER = os.path.join(models_path_prefix, "facebook/bart-base") TOKENIZER = os.path.join(models_path_prefix, "Isotr0py/Florence-2-tokenizer")
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign": "stop_sign":
"<CAPTION>", # special task token "<OD>", # special task token which will output special tokens
"cherry_blossom": "cherry_blossom":
"Describe in detail what is shown in the image.", "Describe in detail what is shown in the image.",
}) })
...@@ -47,7 +47,6 @@ def hf_to_vllm_output(hf_output: tuple[list[int], str, ...@@ -47,7 +47,6 @@ def hf_to_vllm_output(hf_output: tuple[list[int], str,
output_ids, output_str, out_logprobs = hf_output output_ids, output_str, out_logprobs = hf_output
output_str = output_str.replace("</s>", "").replace("<s>", "") output_str = output_str.replace("</s>", "").replace("<s>", "")
output_ids = [ids for ids in output_ids if ids not in [0, 2]]
return output_ids, output_str, out_logprobs return output_ids, output_str, out_logprobs
...@@ -73,8 +72,11 @@ def run_test( ...@@ -73,8 +72,11 @@ def run_test(
enforce_eager=True) as vllm_model: enforce_eager=True) as vllm_model:
vllm_outputs_per_case = [ vllm_outputs_per_case = [
vllm_model.generate_encoder_decoder_greedy_logprobs( vllm_model.generate_encoder_decoder_greedy_logprobs(
prompts, max_tokens, num_logprobs=num_logprobs) prompts,
for prompts in inputs max_tokens,
num_logprobs=num_logprobs,
skip_special_tokens=False,
) for prompts in inputs
] ]
hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs] hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs]
...@@ -95,6 +97,7 @@ def run_test( ...@@ -95,6 +97,7 @@ def run_test(
outputs_1_lst=vllm_outputs, outputs_1_lst=vllm_outputs,
name_0="hf", name_0="hf",
name_1="vllm", name_1="vllm",
num_outputs_0_skip_tokens=1,
) )
......
...@@ -254,10 +254,12 @@ def _test_processing_correctness_mistral( ...@@ -254,10 +254,12 @@ def _test_processing_correctness_mistral(
"adept/fuyu-8b", "adept/fuyu-8b",
"google/gemma-3-4b-it", "google/gemma-3-4b-it",
"THUDM/glm-4v-9b", "THUDM/glm-4v-9b",
"ibm-granite/granite-speech-3.3-8b",
"h2oai/h2ovl-mississippi-800m", "h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL2-1B",
"HuggingFaceM4/Idefics3-8B-Llama3", "HuggingFaceM4/Idefics3-8B-Llama3",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct", "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
"moonshotai/Kimi-VL-A3B-Instruct",
"meta-llama/Llama-4-Scout-17B-16E-Instruct", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"llava-hf/llava-1.5-7b-hf", "llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/llava-v1.6-mistral-7b-hf",
...@@ -273,12 +275,14 @@ def _test_processing_correctness_mistral( ...@@ -273,12 +275,14 @@ def _test_processing_correctness_mistral(
"nvidia/NVLM-D-72B", "nvidia/NVLM-D-72B",
"google/paligemma-3b-mix-224", "google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448", "google/paligemma2-3b-ft-docci-448",
"microsoft/Phi-4-multimodal-instruct",
"mistralai/Pixtral-12B-2409", "mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b", "mistral-community/pixtral-12b",
"Qwen/Qwen-VL-Chat", "Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct",
"Qwen/Qwen2.5-Omni-7B",
"Skywork/Skywork-R1V-38B", "Skywork/Skywork-R1V-38B",
"fixie-ai/ultravox-v0_5-llama-3_2-1b", "fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3", "openai/whisper-large-v3",
......
# SPDX-License-Identifier: Apache-2.0
"""Tests for phi4mm's multimodal preprocessing kwargs."""
import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets
from ...utils import build_model_context
@pytest.mark.parametrize("model_id", ["microsoft/Phi-4-multimodal-instruct"])
# yapf: disable
@pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img"),
[
({"dynamic_hd": 4}, 1329),
({"dynamic_hd": 16}, 4433),
# the default num_crops of phi-4-multimodal is 36
({}, 9585),
])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict[str, int],
expected_toks_per_img: int,
num_imgs: int,
kwargs_on_init: bool,
):
"""Ensure Phi4MMMultiModalProcessor handles dynamic_hd properly."""
# Avoid initializing CUDA early
from vllm.model_executor.models.phi4mm import _IMAGE_PLACEHOLDER_TOKEN_ID
ctx = build_model_context(
model_id,
mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
limit_mm_per_prompt={"image": num_imgs},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs
# Build the image str / prompt based on the number of images we pass
img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
image_size = ctx.get_hf_config(
).embd_layer["image_embd_layer"]["crop_size"]
dummy_image_size = (image_size * 7, image_size * 7)
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
# Ensure we have the right number of placeholders per num_crops size
img_tok_count = processed_inputs["prompt_token_ids"].count(
_IMAGE_PLACEHOLDER_TOKEN_ID)
assert img_tok_count == expected_toks_per_img * num_imgs
...@@ -121,9 +121,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -121,9 +121,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
trust_remote_code=True), trust_remote_code=True),
"BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B"), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B"),
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"), "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m",
{"1b": "bigscience/bloomz-1b1"}),
"ChatGLMModel": _HfExamplesInfo("THUDM/chatglm3-6b", "ChatGLMModel": _HfExamplesInfo("THUDM/chatglm3-6b",
trust_remote_code=True), trust_remote_code=True,
max_transformers_version="4.48"),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("thu-coai/ShieldLM-6B-chatglm3", # noqa: E501 "ChatGLMForConditionalGeneration": _HfExamplesInfo("thu-coai/ShieldLM-6B-chatglm3", # noqa: E501
trust_remote_code=True), trust_remote_code=True),
"CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01", "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
...@@ -141,24 +143,26 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -141,24 +143,26 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501 "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it", "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
min_transformers_version="4.50"),
"GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo( "Glm4ForCausalLM": _HfExamplesInfo(
"THUDM/GLM-4-32B-0414", "THUDM/GLM-4-32B-0414",
is_available_online=False, is_available_online=False,
min_transformers_version="4.52.dev0" min_transformers_version="4.52.dev0"
), ),
"GPT2LMHeadModel": _HfExamplesInfo("gpt2"), "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2",
"GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"), {"alias": "gpt2"}),
"GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"), "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder",
"GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"), {"tiny": "bigcode/tiny_starcoder_py"}), # noqa: E501
"GPTJForCausalLM": _HfExamplesInfo("Milos/slovak-gpt-j-405M",
{"6b": "EleutherAI/gpt-j-6b"}),
"GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m",
{"1b": "EleutherAI/pythia-1.4b"}),
"GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
"GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
"GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts", # noqa: E501 "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501
min_transformers_version="4.49"), # noqa: E501
"Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1",
trust_remote_code=True), trust_remote_code=True),
"InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b", "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
...@@ -186,7 +190,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -186,7 +190,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01", "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
trust_remote_code=True), trust_remote_code=True),
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501 "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", # noqa: E501
{"falcon3": "ehristoforu/Falcon3-MoE-2x7B-Insruct"}), # noqa: E501
"QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501 "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501
"MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False), "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
...@@ -194,7 +199,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -194,7 +199,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
"Olmo2ForCausalLM": _HfExamplesInfo("shanearora/OLMo-7B-1124-hf"), "Olmo2ForCausalLM": _HfExamplesInfo("shanearora/OLMo-7B-1124-hf"),
"OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"), "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
"OPTForCausalLM": _HfExamplesInfo("facebook/opt-iml-max-1.3b"), "OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m",
{"1b": "facebook/opt-iml-max-1.3b"}),
"OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat", "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
trust_remote_code=True), trust_remote_code=True),
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
...@@ -204,10 +210,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -204,10 +210,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True), trust_remote_code=True),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True), trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
trust_remote_code=True), trust_remote_code=True),
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct", "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-0.5B-Instruct",
extras={"2.5": "Qwen/Qwen2.5-7B-Instruct"}), # noqa: E501 extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}), # noqa: E501
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo( "Qwen3ForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-8B", "Qwen/Qwen3-8B",
...@@ -233,8 +241,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -233,8 +241,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat", "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
is_available_online=False, is_available_online=False,
trust_remote_code=True), trust_remote_code=True),
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct", "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"),
min_transformers_version="4.49"),
# [Encoder-decoder] # [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"), "BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
...@@ -245,11 +252,15 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -245,11 +252,15 @@ _EMBEDDING_EXAMPLE_MODELS = {
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward", "InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
trust_remote_code=True), trust_remote_code=True),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501 "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", # noqa: E501
trust_remote_code=True),
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"), "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
...@@ -273,6 +284,7 @@ _CROSS_ENCODER_EXAMPLE_MODELS = { ...@@ -273,6 +284,7 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501 "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501
} }
_MULTIMODAL_EXAMPLE_MODELS = { _MULTIMODAL_EXAMPLE_MODELS = {
...@@ -286,10 +298,11 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -286,10 +298,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501 extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501 max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible.", # noqa: E501 transformers_version_reason="HF model is not compatible.", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it", "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
min_transformers_version="4.50"), "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-8b", # noqa: E501
min_transformers_version="4.52.0"), # noqa: E501
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b", "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
trust_remote_code=True, trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
...@@ -302,6 +315,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -302,6 +315,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True), trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501 {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
trust_remote_code=True),
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
min_transformers_version="4.51"), min_transformers_version="4.51"),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
...@@ -322,7 +338,6 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -322,7 +338,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501 extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501
trust_remote_code=True), trust_remote_code=True),
"Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501 "Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501
min_transformers_version="4.50", # noqa: E501
extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501 extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
max_transformers_version="4.48", max_transformers_version="4.48",
...@@ -348,8 +363,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -348,8 +363,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}), # noqa: E501 hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}), # noqa: E501
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501 "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct"), # noqa: E501
min_transformers_version="4.49"), # noqa: E501 "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B", # noqa: E501
min_transformers_version="4.52"), # noqa: E501
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"), "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
...@@ -358,7 +374,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -358,7 +374,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model # Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
tokenizer="facebook/bart-base", tokenizer="Isotr0py/Florence-2-tokenizer",
trust_remote_code=True), # noqa: E501 trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501 "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
...@@ -379,6 +395,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -379,6 +395,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
trust_remote_code=True, trust_remote_code=True,
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501
"Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
} }
_TRANSFORMERS_MODELS = { _TRANSFORMERS_MODELS = {
......
# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of a GPTQ model to a bitblas model.
Note: GPTQ and bitblas do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
bitblas/GPTQ models are in the top 3 selections of each other.
Note: bitblas internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for bitblas. As a result, we re-run the
test up to 3 times to see if we pass.
Run `pytest tests/models/test_bitblas.py`.
"""
from dataclasses import dataclass
import pytest
from .utils import check_logprobs_close
@dataclass
class ModelPair:
model_bitblas: str
model_gptq: str
model_pairs = [
ModelPair(model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas",
model_gptq="hxbgsyxh/opt-125m-4bit-128g"),
]
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
vllm_runner,
example_prompts,
model_pair: ModelPair,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with vllm_runner(model_pair.model_bitblas,
dtype=dtype,
quantization="bitblas") as bitblas_model:
bitblas_outputs = bitblas_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model_pair.model_gptq, dtype=dtype,
quantization="gptq") as gptq_model:
gptq_outputs = gptq_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=gptq_outputs,
outputs_1_lst=bitblas_outputs,
name_0="gptq",
name_1="bitblas",
)
# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of a GPTQ model to a bitblas model.
Note: GPTQ and bitblas do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
bitblas/GPTQ models are in the top 3 selections of each other.
Note: bitblas internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for bitblas. As a result, we re-run the
test up to 3 times to see if we pass.
Run `pytest tests/models/test_bitblas.py`.
"""
from dataclasses import dataclass
import pytest
from .utils import check_logprobs_close
@dataclass
class ModelPair:
model_gptq: str
model_pairs = [
ModelPair(model_gptq="hxbgsyxh/opt-125m-4bit-128g"),
]
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
vllm_runner,
example_prompts,
model_pair: ModelPair,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with vllm_runner(model_pair.model_gptq,
dtype=dtype,
quantization="bitblas") as bitblas_model:
bitblas_outputs = bitblas_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model_pair.model_gptq, dtype=dtype,
quantization="gptq") as gptq_model:
gptq_outputs = gptq_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=gptq_outputs,
outputs_1_lst=bitblas_outputs,
name_0="gptq",
name_1="gptq_bitblas",
)
...@@ -24,10 +24,7 @@ def test_can_initialize(model_arch): ...@@ -24,10 +24,7 @@ def test_can_initialize(model_arch):
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
hf_config.update(model_info.hf_overrides) hf_config.update(model_info.hf_overrides)
if hasattr(hf_config, "text_config"): text_config = hf_config.get_text_config()
text_config: PretrainedConfig = hf_config.text_config
else:
text_config = hf_config
text_config.update({ text_config.update({
"num_layers": 1, "num_layers": 1,
......
...@@ -18,10 +18,9 @@ def test_plugin( ...@@ -18,10 +18,9 @@ def test_plugin(
m.setenv("VLLM_USE_V1", "0") m.setenv("VLLM_USE_V1", "0")
m.setenv("VLLM_PLUGINS", "") m.setenv("VLLM_PLUGINS", "")
with pytest.raises(Exception) as excinfo: match = "Cannot find model module"
with pytest.raises(ValueError, match=match):
LLM(model=dummy_opt_path, load_format="dummy") LLM(model=dummy_opt_path, load_format="dummy")
error_msg = "has no vLLM implementation and the Transformers implementation is not compatible with vLLM" # noqa: E501
assert (error_msg in str(excinfo.value))
@create_new_process_for_each_test() @create_new_process_for_each_test()
......
...@@ -266,16 +266,23 @@ def test_compressed_tensors_w8a8_dynamic_per_token( ...@@ -266,16 +266,23 @@ def test_compressed_tensors_w8a8_dynamic_per_token(
reason="WNA16 is not supported on ROCm.") reason="WNA16 is not supported on ROCm.")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"wNa16_args", "wNa16_args",
[ [(os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w4a16-channel-v2"), "channel", None, 8,
(os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w4a16-channel-v2"), "channel", None, 8), True, False),
(os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w4a16-group128-v2"), "group", 128, 8), (os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w4a16-group128-v2"), "group", 128, 8, True,
(os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w8a16-per-channel"), "channel", None, 4), False),
], (os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w8a16-per-channel"), "channel", None, 4,
True, False),
(os.path.join(models_path_prefix, "nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256"), "group", 128,
8, False, False),
(os.path.join(models_path_prefix, "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel"),
"channel", None, 8, False, False),
(os.path.join(models_path_prefix, "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder"),
"group", 128, 8, False, True)],
) )
@pytest.mark.skipif(not current_platform.is_cuda(), @pytest.mark.skipif(not current_platform.is_cuda(),
reason="The tests are skipped on non-CUDA platform.") reason="The tests are skipped on non-CUDA platform.")
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
model, strategy, group, pack_factor = wNa16_args model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args
with vllm_runner(model) as llm: with vllm_runner(model) as llm:
def check_model(model): def check_model(model):
...@@ -291,6 +298,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): ...@@ -291,6 +298,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
if group is None else group) if group is None else group)
assert qkv_proj.scheme.pack_factor == pack_factor assert qkv_proj.scheme.pack_factor == pack_factor
assert qkv_proj.scheme.symmetric == symmetric
assert qkv_proj.scheme.has_g_idx == has_g_idx
llm.apply_model(check_model) llm.apply_model(check_model)
......
...@@ -7,6 +7,9 @@ Run `pytest tests/samplers/test_beam_search.py`. ...@@ -7,6 +7,9 @@ Run `pytest tests/samplers/test_beam_search.py`.
import pytest import pytest
import os import os
from ..utils import models_path_prefix from ..utils import models_path_prefix
from transformers import AutoModelForSeq2SeqLM
from vllm.assets.audio import AudioAsset
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
...@@ -21,7 +24,8 @@ def v1(run_with_both_engines): ...@@ -21,7 +24,8 @@ def v1(run_with_both_engines):
# 3. Use the model "huggyllama/llama-7b". # 3. Use the model "huggyllama/llama-7b".
MAX_TOKENS = [64] MAX_TOKENS = [64]
BEAM_WIDTHS = [4] BEAM_WIDTHS = [4]
MODELS = [os.path.join(models_path_prefix, "TinyLlama/TinyLlama-1.1B-Chat-v1.0")] MM_BEAM_WIDTHS = [2]
MODELS = [os.path.join(models_path_prefix, "TinyLlama/TinyLlama-1.1B-Chat-v1.0")]
@pytest.mark.skip_v1 # FIXME: This fails on V1 right now. @pytest.mark.skip_v1 # FIXME: This fails on V1 right now.
...@@ -50,15 +54,90 @@ def test_beam_search_single_input( ...@@ -50,15 +54,90 @@ def test_beam_search_single_input(
for i in range(len(example_prompts)): for i in range(len(example_prompts)):
hf_output_ids, hf_output_texts = hf_outputs[i] hf_output_ids, hf_output_texts = hf_outputs[i]
vllm_output_ids, vllm_output_texts = vllm_outputs[i] vllm_output_ids, vllm_output_texts = vllm_outputs[i]
for i, (hf_text, for j, (hf_text,
vllm_text) in enumerate(zip(hf_output_texts, vllm_text) in enumerate(zip(hf_output_texts,
vllm_output_texts)): vllm_output_texts)):
print(f">>>{i}-th hf output:") print(f">>>{j}-th hf output:")
print(hf_text) print(hf_text)
print(f">>>{i}-th vllm output:") print(f">>>{j}-th vllm output:")
print(vllm_text) print(vllm_text)
assert len(hf_output_ids) == len(vllm_output_ids) assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)): for j in range(len(hf_output_ids)):
assert hf_output_ids[j] == vllm_output_ids[j], ( assert hf_output_ids[j] == vllm_output_ids[j], (
f"Test{i} output{j}:\nHF: {hf_output_ids}\n" f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}") f"vLLM: {vllm_output_ids}")
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
def test_beam_search_passes_multimodal_data(
hf_runner,
vllm_runner,
dtype: str,
max_tokens: int,
beam_width: int,
) -> None:
"""Ensure that beam search passes multimodal data through correctly."""
# NOTE - this test is primarily to check that mm data is passed to beams
# correctly. As such, we just need to check one extra modality to make
# sure things pass through properly.
audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
model = "Qwen/Qwen2-Audio-7B-Instruct"
audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>"
prompts = [
f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n" #noqa: E501
]
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
audio_token_id = hf_model.config.audio_token_index
eos_token_id = hf_model.tokenizer.eos_token_id # <|im_end|>
hf_outputs = hf_model.generate_beam_search(
prompts,
beam_width=beam_width,
max_tokens=max_tokens,
audios=audios,
)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(
prompts,
beam_width=beam_width,
max_tokens=max_tokens,
audios=audios,
)
seq_with_no_audio_toks = lambda seq: [
tok for tok in seq if tok != audio_token_id
]
for i in range(len(prompts)):
hf_output_ids, hf_output_texts = hf_outputs[i]
vllm_output_ids, vllm_output_texts = vllm_outputs[i]
for j, (hf_text,
vllm_text) in enumerate(zip(hf_output_texts,
vllm_output_texts)):
print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
print(hf_text)
print(f">>>{j}-th vllm output:")
print(vllm_text)
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
# Compare everything except for the audio tokens; we do this since
# the IDs returned from the transformers helper expands the audio
# token to match features, while the vLLM helper maintains the
# single audio token in the input text
filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j])
filtered_vllm_output_ids = seq_with_no_audio_toks(
vllm_output_ids[j])
# HF output IDs may contain the end of sequence
if len(filtered_hf_output_ids
) == len(filtered_vllm_output_ids) + 1:
assert filtered_hf_output_ids[-1] == eos_token_id
filtered_hf_output_ids = filtered_hf_output_ids[:-1]
assert filtered_hf_output_ids == filtered_vllm_output_ids
...@@ -64,9 +64,8 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int, ...@@ -64,9 +64,8 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
scorer_worker = create_worker(Worker, model_name, block_size, scorer_worker = create_worker(Worker, model_name, block_size,
num_gpu_blocks, seed) num_gpu_blocks, seed)
scorer_worker.model_runner.disable_logprobs = True # accessed by mqa_scorer scorer_worker.model_runner.disable_logprobs = True # accessed by mqa_scorer
scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True scorer_worker.model_runner.sampler.include_gpu_probs_tensor = True
scorer_worker.model_runner.model.sampler.\ scorer_worker.model_runner.sampler.should_modify_greedy_probs_inplace = True
should_modify_greedy_probs_inplace = True
vocab_size = scorer_worker.vocab_size vocab_size = scorer_worker.vocab_size
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dataclasses import asdict from dataclasses import MISSING, Field, asdict, dataclass, field
import pytest import pytest
import os import os
from vllm.config import ModelConfig, PoolerConfig from vllm.config import ModelConfig, PoolerConfig, get_field
from vllm.model_executor.layers.pooler import PoolingType from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform from vllm.platforms import current_platform
from utils import models_path_prefix from utils import models_path_prefix
def test_get_field():
@dataclass
class TestConfig:
a: int
b: dict = field(default_factory=dict)
c: str = "default"
with pytest.raises(ValueError):
get_field(TestConfig, "a")
b = get_field(TestConfig, "b")
assert isinstance(b, Field)
assert b.default is MISSING
assert b.default_factory is dict
c = get_field(TestConfig, "c")
assert isinstance(c, Field)
assert c.default == "default"
assert c.default_factory is MISSING
@pytest.mark.parametrize( @pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"), ("model_id", "expected_runner_type", "expected_task"),
[ [
......
...@@ -13,11 +13,11 @@ import torch ...@@ -13,11 +13,11 @@ import torch
from vllm_test_utils.monitor import monitor from vllm_test_utils.monitor import monitor
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (FlexibleArgumentParser, MemorySnapshot, from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
PlaceholderModule, StoreBoolean, bind_kv_cache, MemorySnapshot, PlaceholderModule, StoreBoolean,
deprecate_kwargs, get_open_port, memory_profiling, bind_kv_cache, deprecate_kwargs, get_open_port,
merge_async_iterators, sha256, supports_kw, memory_profiling, merge_async_iterators, sha256,
swap_dict_values) supports_kw, swap_dict_values)
from .utils import create_new_process_for_each_test, error_on_warning from .utils import create_new_process_for_each_test, error_on_warning
from .utils import models_path_prefix from .utils import models_path_prefix
...@@ -419,6 +419,129 @@ def test_bind_kv_cache_pp(): ...@@ -419,6 +419,129 @@ def test_bind_kv_cache_pp():
assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0] assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
class TestLRUCache(LRUCache):
def _on_remove(self, key, value):
if not hasattr(self, "_remove_counter"):
self._remove_counter = 0
self._remove_counter += 1
def test_lru_cache():
cache = TestLRUCache(3)
assert cache.stat() == CacheInfo(hits=0, total=0)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
cache.put(1, 1)
assert len(cache) == 1
cache.put(1, 1)
assert len(cache) == 1
cache.put(2, 2)
assert len(cache) == 2
cache.put(3, 3)
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}
cache.put(4, 4)
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache.get(2) == 2
assert cache.stat() == CacheInfo(hits=1, total=1)
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
assert cache[2] == 2
assert cache.stat() == CacheInfo(hits=2, total=2)
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
cache.put(5, 5)
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2
assert cache.pop(5) == 5
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
assert cache.get(-1) is None
assert cache.stat() == CacheInfo(hits=2, total=3)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)
cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.get(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.put(6, 6)
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache
cache.remove_oldest()
assert len(cache) == 2
assert set(cache.cache) == {2, 6}
assert cache._remove_counter == 4
cache.clear()
assert len(cache) == 0
assert cache._remove_counter == 6
assert cache.stat() == CacheInfo(hits=0, total=0)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
cache._remove_counter = 0
cache[1] = 1
assert len(cache) == 1
cache[1] = 1
assert len(cache) == 1
cache[2] = 2
assert len(cache) == 2
cache[3] = 3
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}
cache[4] = 4
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache[2] == 2
cache[5] = 5
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2
del cache[5]
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache[6] = 6
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache
def test_placeholder_module_error_handling(): def test_placeholder_module_error_handling():
placeholder = PlaceholderModule("placeholder_1234") placeholder = PlaceholderModule("placeholder_1234")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pickle
from copy import deepcopy from copy import deepcopy
import os import os
import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from ..utils import models_path_prefix from ..utils import models_path_prefix
def test_cached_tokenizer(): from vllm.transformers_utils.tokenizer import (AnyTokenizer,
reference_tokenizer = AutoTokenizer.from_pretrained(os.path.join(models_path_prefix, "gpt2")) get_cached_tokenizer)
@pytest.mark.parametrize("model_id", [os.path.join(models_path_prefix, "gpt2"), os.path.join(models_path_prefix, "THUDM/chatglm3-6b")])
def test_cached_tokenizer(model_id: str):
reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
trust_remote_code=True)
reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"}) reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
reference_tokenizer.add_special_tokens( reference_tokenizer.add_special_tokens(
{"additional_special_tokens": ["<SEP>"]}) {"additional_special_tokens": ["<SEP>"]})
cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer)) cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
_check_consistency(cached_tokenizer, reference_tokenizer)
pickled_tokenizer = pickle.dumps(cached_tokenizer)
unpickled_tokenizer = pickle.loads(pickled_tokenizer)
_check_consistency(unpickled_tokenizer, reference_tokenizer)
def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
assert isinstance(target, type(expected))
# Cached attributes
assert target.all_special_ids == expected.all_special_ids
assert target.all_special_tokens == expected.all_special_tokens
assert (target.all_special_tokens_extended ==
expected.all_special_tokens_extended)
assert target.get_vocab() == expected.get_vocab()
assert len(target) == len(expected)
# Other attributes
assert getattr(target, "padding_side",
None) == getattr(expected, "padding_side", None)
assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode( assert target.encode("prompt") == expected.encode("prompt")
"prompt")
assert set(reference_tokenizer.all_special_ids) == set(
cached_tokenizer.all_special_ids)
assert set(reference_tokenizer.all_special_tokens) == set(
cached_tokenizer.all_special_tokens)
assert set(reference_tokenizer.all_special_tokens_extended) == set(
cached_tokenizer.all_special_tokens_extended)
...@@ -3,19 +3,26 @@ ...@@ -3,19 +3,26 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional from typing import Any, Optional
import pytest
import os import os
from transformers import AutoTokenizer import pytest
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from vllm.inputs import token_inputs from vllm.inputs import token_inputs
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.detokenizer import (Detokenizer, from vllm.transformers_utils.detokenizer import Detokenizer
detokenize_incrementally) from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
IncrementalDetokenizer,
SlowIncrementalDetokenizer)
from ..utils import models_path_prefix from ..utils import models_path_prefix
SPECIAL_TOKS_TRUTH = [
"Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>", # noqa
]
TRUTH = [ TRUTH = [
"Hello here, this is a simple test", "Hello here, this is a simple test",
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", # noqa "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", # noqa
...@@ -25,7 +32,8 @@ TRUTH = [ ...@@ -25,7 +32,8 @@ TRUTH = [
# incomplete UTF-8 characters # incomplete UTF-8 characters
# see https://github.com/vllm-project/vllm/pull/9625 # see https://github.com/vllm-project/vllm/pull/9625
"ပုံပြင်လေးပြောပြပါ်", "ပုံပြင်လေးပြောပြပါ်",
] ] + SPECIAL_TOKS_TRUTH
TOKENIZERS = [ TOKENIZERS = [
os.path.join(models_path_prefix, "facebook/opt-125m"), os.path.join(models_path_prefix, "facebook/opt-125m"),
os.path.join(models_path_prefix, "gpt2"), os.path.join(models_path_prefix, "gpt2"),
...@@ -41,26 +49,37 @@ TOKENIZERS = [ ...@@ -41,26 +49,37 @@ TOKENIZERS = [
] ]
def _run_incremental_decode(tokenizer, all_input_ids, def _run_incremental_decode(tokenizer,
skip_special_tokens: bool, starting_index: int): all_input_ids,
decoded_text = "" skip_special_tokens: bool,
offset = 0 starting_index: int,
token_offset = 0 spaces_between_special_tokens: bool = True,
prev_tokens = None fast: Optional[bool] = None):
for i in range(starting_index, len(all_input_ids)):
new_tokens, text, offset, token_offset = detokenize_incrementally( prompt_token_ids = all_input_ids[:starting_index]
tokenizer,
all_input_ids[:i + 1], params = SamplingParams(
prev_tokens, skip_special_tokens=skip_special_tokens,
offset, spaces_between_special_tokens=spaces_between_special_tokens,
token_offset, )
skip_special_tokens=skip_special_tokens) request = EngineCoreRequest("", prompt_token_ids, None, None, None, params,
decoded_text += text None, 0.0, None)
if prev_tokens is None:
prev_tokens = new_tokens if fast is None:
else: detokenizer = IncrementalDetokenizer.from_new_request(
prev_tokens += new_tokens tokenizer, request)
return decoded_text elif fast:
detokenizer = FastIncrementalDetokenizer(tokenizer, request)
else:
detokenizer = SlowIncrementalDetokenizer(tokenizer, request)
output_text = ""
for i, token_id in enumerate(all_input_ids[starting_index:]):
detokenizer.update([token_id], False)
finished = i == len(all_input_ids) - 1
output_text += detokenizer.get_next_output_text(finished, delta=True)
return output_text, detokenizer.output_token_ids
@pytest.fixture @pytest.fixture
...@@ -88,11 +107,13 @@ def test_mistral_edge_case(tokenizer, truth): ...@@ -88,11 +107,13 @@ def test_mistral_edge_case(tokenizer, truth):
starting_index = 0 starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
decoded_text = _run_incremental_decode(tokenizer, decoded_text, out_ids = _run_incremental_decode(
all_input_ids, tokenizer,
skip_special_tokens=True, all_input_ids,
starting_index=starting_index) skip_special_tokens=True,
starting_index=starting_index)
assert decoded_text == truth assert decoded_text == truth
assert out_ids == all_input_ids[starting_index:]
@pytest.fixture @pytest.fixture
...@@ -109,45 +130,91 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]: ...@@ -109,45 +130,91 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
@pytest.mark.parametrize("with_prompt", [True, False]) @pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True) @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens): @pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
spaces_between_special_tokens, fast):
if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
pytest.skip()
if skip_special_tokens and not spaces_between_special_tokens:
pytest.skip()
if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
# Fix up inconsistency in fast/slow tokenizer behaviour.
tokenizer.add_special_tokens({
"additional_special_tokens": [
at for at in
tokenizer._tokenizer.get_added_tokens_decoder().values()
if at.special
]
})
extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
else {"spaces_between_special_tokens": spaces_between_special_tokens}
truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
if tokenizer.bos_token_id is not None:
truth_tokens.insert(0, tokenizer.bos_token_id)
truth_tokens.append(tokenizer.eos_token_id)
new_truth = tokenizer.decode(truth_tokens,
skip_special_tokens=skip_special_tokens,
**extra_decode_args)
if with_prompt: if with_prompt:
truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids num_prompt_tokens = len(
prompt_input_ids = truth_tokens[:len(truth) // 2] tokenizer(truth[:len(truth) // 2],
generated_input_ids = truth_tokens[len(truth) // 2:] add_special_tokens=False).input_ids)
if tokenizer.bos_token_id is not None:
num_prompt_tokens += 1
prompt_input_ids = truth_tokens[:num_prompt_tokens]
generated_input_ids = truth_tokens[num_prompt_tokens:]
all_input_ids = prompt_input_ids + generated_input_ids all_input_ids = prompt_input_ids + generated_input_ids
starting_index = len(prompt_input_ids) starting_index = len(prompt_input_ids)
prompt = tokenizer.decode(prompt_input_ids, prompt = tokenizer.decode(prompt_input_ids,
skip_special_tokens=skip_special_tokens) skip_special_tokens=skip_special_tokens,
generated = truth[len(prompt):] **extra_decode_args)
generated = new_truth[len(prompt):]
else: else:
generated = truth generated = new_truth
starting_index = 0 starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids all_input_ids = truth_tokens
if skip_special_tokens:
if tokenizer.bos_token_id is not None:
all_input_ids = [tokenizer.bos_token_id] + all_input_ids
starting_index += 1
all_input_ids = all_input_ids + [tokenizer.eos_token_id]
decoded_text = _run_incremental_decode( decoded_text, out_ids = _run_incremental_decode(
tokenizer, tokenizer,
all_input_ids, all_input_ids,
skip_special_tokens=skip_special_tokens, skip_special_tokens=skip_special_tokens,
starting_index=starting_index) starting_index=starting_index,
spaces_between_special_tokens=spaces_between_special_tokens,
fast=fast)
assert decoded_text == generated assert decoded_text == generated
assert out_ids == all_input_ids[starting_index:]
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("fast", (True, False))
def test_oov_decode(tokenizer, fast):
if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
pytest.skip()
decoded_text = _run_incremental_decode( decoded_text, out_ids = _run_incremental_decode(
tokenizer, [len(tokenizer)], tokenizer, [len(tokenizer)],
skip_special_tokens=skip_special_tokens, skip_special_tokens=True,
starting_index=starting_index) starting_index=0,
spaces_between_special_tokens=True,
fast=fast)
assert decoded_text == '' assert decoded_text == ''
assert out_ids == [len(tokenizer)]
@pytest.fixture @pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer: def detokenizer(tokenizer_name: str) -> Detokenizer:
init_kwargs = dict( tokenizer_group = TokenizerGroup(
tokenizer_id=tokenizer_name, tokenizer_id=tokenizer_name,
enable_lora=False, enable_lora=False,
max_num_seqs=100, max_num_seqs=100,
...@@ -157,26 +224,20 @@ def detokenizer(tokenizer_name: str) -> Detokenizer: ...@@ -157,26 +224,20 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
revision=None, revision=None,
) )
tokenizer_group = get_tokenizer_group(
None,
**init_kwargs,
)
return Detokenizer(tokenizer_group) return Detokenizer(tokenizer_group)
@pytest.fixture(name="complete_sequence_token_ids") @pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str, def create_complete_sequence_token_ids(complete_sequence: str,
tokenizer) -> list[int]: tokenizer) -> list[int]:
complete_sequence_token_ids = tokenizer(complete_sequence).input_ids return tokenizer(complete_sequence, add_special_tokens=False).input_ids
return complete_sequence_token_ids
def create_sequence(prompt_token_ids=None): def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or [1] prompt_token_ids = prompt_token_ids or []
return Sequence( return Sequence(
seq_id=0, seq_id=0,
inputs=token_inputs(prompt_token_ids, prompt="<s>"), inputs=token_inputs(prompt_token_ids),
block_size=16, block_size=16,
) )
...@@ -227,7 +288,7 @@ def test_decode_sequence_logprobs(complete_sequence: str, ...@@ -227,7 +288,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
assert sequential_result == "".join(sequential_logprobs_text_chosen_token) assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
assert sequential_result != "".join(sequential_logprobs_text_other_token) assert sequential_result != "".join(sequential_logprobs_text_other_token)
if skip_special_tokens: if not skip_special_tokens:
# Text for logprobs for the chosen token should be the same as the # Text for logprobs for the chosen token should be the same as the
# generated text. Note that this will only be true if we skip # generated text. Note that this will only be true if we skip
# special tokens. # special tokens.
...@@ -236,10 +297,23 @@ def test_decode_sequence_logprobs(complete_sequence: str, ...@@ -236,10 +297,23 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int], def test_decode_prompt_logprobs(complete_sequence: str,
complete_sequence_token_ids: list[int],
detokenizer: Detokenizer): detokenizer: Detokenizer):
# We want to use skip_special_tokens=False here but Mistral tokenizers
# don't support that.
if complete_sequence not in SPECIAL_TOKS_TRUTH:
skip_special_tokens = True
elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
MistralTokenizer):
skip_special_tokens = False
else:
pytest.skip("MistralTokenizers don't support "
"skip_special_tokens=False")
return
"""Verify Detokenizer decodes prompt logprobs correctly.""" """Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=True, sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
prompt_logprobs=1) prompt_logprobs=1)
# Run sequentially. # Run sequentially.
...@@ -259,8 +333,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int], ...@@ -259,8 +333,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
# decoded_prompt_logprobs doesn't contain the first token. # decoded_prompt_logprobs doesn't contain the first token.
token_ids = complete_sequence_token_ids token_ids = complete_sequence_token_ids
tokenizer = detokenizer.get_tokenizer_for_seq(seq) tokenizer = detokenizer.get_tokenizer_for_seq(seq)
text_full = tokenizer.decode(token_ids, skip_special_tokens=True) text_full = tokenizer.decode(token_ids,
text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True) skip_special_tokens=skip_special_tokens)
text_first = tokenizer.decode(token_ids[0],
skip_special_tokens=skip_special_tokens)
text = text_full[len(text_first):] text = text_full[len(text_first):]
# Text for logprobs for the chosen token should be the same as the # Text for logprobs for the chosen token should be the same as the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment