Unverified Commit 4e256cad authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Remove all references to `yapf` as it's no longer used (#26251)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent d6953beb
......@@ -169,430 +169,625 @@ class _HfExamplesInfo:
pytest.skip(msg)
# yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS = {
# [Decoder-only]
"ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-2509",
min_transformers_version="4.56.0",
trust_remote_code=True),
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B",
trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
trust_remote_code=True),
"ApertusForCausalLM": _HfExamplesInfo(
"swiss-ai/Apertus-8B-2509",
min_transformers_version="4.56.0",
trust_remote_code=True,
),
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True),
"ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"),
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
trust_remote_code=True),
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
trust_remote_code=True),
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
trust_remote_code=True),
"BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5",
trust_remote_code=True),
"BailingMoeV2ForCausalLM": _HfExamplesInfo("inclusionAI/Ling-mini-2.0",
trust_remote_code=True),
"BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1",
min_transformers_version="4.55.3",
extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m",
{"1b": "bigscience/bloomz-1b1"}),
"ChatGLMModel": _HfExamplesInfo("zai-org/chatglm3-6b",
trust_remote_code=True,
max_transformers_version="4.48"),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("thu-coai/ShieldLM-6B-chatglm3", # noqa: E501
trust_remote_code=True),
"CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
trust_remote_code=True),
"Cohere2ForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r7b-12-2024", # noqa: E501
trust_remote_code=True),
"CwmForCausalLM": _HfExamplesInfo("facebook/cwm", # noqa: E501
trust_remote_code=True,
is_available_online=False),
"ArcticForCausalLM": _HfExamplesInfo(
"Snowflake/snowflake-arctic-instruct", trust_remote_code=True
),
"BaiChuanForCausalLM": _HfExamplesInfo(
"baichuan-inc/Baichuan-7B", trust_remote_code=True
),
"BaichuanForCausalLM": _HfExamplesInfo(
"baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True
),
"BailingMoeForCausalLM": _HfExamplesInfo(
"inclusionAI/Ling-lite-1.5", trust_remote_code=True
),
"BailingMoeV2ForCausalLM": _HfExamplesInfo(
"inclusionAI/Ling-mini-2.0", trust_remote_code=True
),
"BambaForCausalLM": _HfExamplesInfo(
"ibm-ai-platform/Bamba-9B-v1",
min_transformers_version="4.55.3",
extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"},
),
"BloomForCausalLM": _HfExamplesInfo(
"bigscience/bloom-560m", {"1b": "bigscience/bloomz-1b1"}
),
"ChatGLMModel": _HfExamplesInfo(
"zai-org/chatglm3-6b", trust_remote_code=True, max_transformers_version="4.48"
),
"ChatGLMForConditionalGeneration": _HfExamplesInfo(
"thu-coai/ShieldLM-6B-chatglm3",
trust_remote_code=True,
),
"CohereForCausalLM": _HfExamplesInfo(
"CohereForAI/c4ai-command-r-v01", trust_remote_code=True
),
"Cohere2ForCausalLM": _HfExamplesInfo(
"CohereForAI/c4ai-command-r7b-12-2024",
trust_remote_code=True,
),
"CwmForCausalLM": _HfExamplesInfo(
"facebook/cwm",
trust_remote_code=True,
is_available_online=False,
),
"DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
"DeciLMForCausalLM": _HfExamplesInfo("nvidia/Llama-3_3-Nemotron-Super-49B-v1", # noqa: E501
trust_remote_code=True),
"DeciLMForCausalLM": _HfExamplesInfo(
"nvidia/Llama-3_3-Nemotron-Super-49B-v1",
trust_remote_code=True,
),
"DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"),
"DeepseekV2ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V2-Lite-Chat", # noqa: E501
trust_remote_code=True),
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
trust_remote_code=True),
"DeepseekV2ForCausalLM": _HfExamplesInfo(
"deepseek-ai/DeepSeek-V2-Lite-Chat",
trust_remote_code=True,
),
"DeepseekV3ForCausalLM": _HfExamplesInfo(
"deepseek-ai/DeepSeek-V3",
trust_remote_code=True,
),
"DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
"Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT",
min_transformers_version="4.54"),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",
min_transformers_version="4.54"),
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
trust_remote_code=True),
"Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B",
min_transformers_version="4.54"),
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"Ernie4_5ForCausalLM": _HfExamplesInfo(
"baidu/ERNIE-4.5-0.3B-PT", min_transformers_version="4.54"
),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo(
"baidu/ERNIE-4.5-21B-A3B-PT", min_transformers_version="4.54"
),
"ExaoneForCausalLM": _HfExamplesInfo(
"LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True
),
"Exaone4ForCausalLM": _HfExamplesInfo(
"LGAI-EXAONE/EXAONE-4.0-32B", min_transformers_version="4.54"
),
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
"FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it",
min_transformers_version="4.53"),
"Gemma3nForCausalLM": _HfExamplesInfo(
"google/gemma-3n-E2B-it", min_transformers_version="4.53"
),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
"Glm4MoeForCausalLM": _HfExamplesInfo("zai-org/GLM-4.5",
min_transformers_version="4.54"), # noqa: E501
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2",
{"alias": "gpt2"}),
"GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder",
extras={"tiny": "bigcode/tiny_starcoder_py"}, # noqa: E501
min_transformers_version="4.55.1",
transformers_version_reason="HF model broken in 4.55.0"), # noqa: E501
"GPTJForCausalLM": _HfExamplesInfo("Milos/slovak-gpt-j-405M",
{"6b": "EleutherAI/gpt-j-6b"}),
"GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m",
{"1b": "EleutherAI/pythia-1.4b"}),
"Glm4MoeForCausalLM": _HfExamplesInfo(
"zai-org/GLM-4.5", min_transformers_version="4.54"
),
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}),
"GPTBigCodeForCausalLM": _HfExamplesInfo(
"bigcode/starcoder",
extras={"tiny": "bigcode/tiny_starcoder_py"},
min_transformers_version="4.55.1",
transformers_version_reason="HF model broken in 4.55.0",
),
"GPTJForCausalLM": _HfExamplesInfo(
"Milos/slovak-gpt-j-405M", {"6b": "EleutherAI/gpt-j-6b"}
),
"GPTNeoXForCausalLM": _HfExamplesInfo(
"EleutherAI/pythia-70m", {"1b": "EleutherAI/pythia-1.4b"}
),
"GptOssForCausalLM": _HfExamplesInfo("lmsys/gpt-oss-20b-bf16"),
"GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
"GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
"GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview", # noqa: E501
min_transformers_version="4.55.3"),
"GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501
"Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1",
trust_remote_code=True),
"HunYuanMoEV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-A13B-Instruct",
trust_remote_code=True),
"GraniteMoeHybridForCausalLM": _HfExamplesInfo(
"ibm-granite/granite-4.0-tiny-preview",
min_transformers_version="4.55.3",
),
"GraniteMoeSharedForCausalLM": _HfExamplesInfo(
"ibm-research/moe-7b-1b-active-shared-experts"
),
"Grok1ModelForCausalLM": _HfExamplesInfo(
"hpcai-tech/grok-1", trust_remote_code=True
),
"HunYuanMoEV1ForCausalLM": _HfExamplesInfo(
"tencent/Hunyuan-A13B-Instruct", trust_remote_code=True
),
# TODO: Remove is_available_online once their config.json is fixed
"HunYuanDenseV1ForCausalLM":_HfExamplesInfo("tencent/Hunyuan-7B-Instruct-0124",
trust_remote_code=True,
is_available_online=False),
"InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
trust_remote_code=True),
"InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b",
trust_remote_code=True),
"InternLM2VEForCausalLM": _HfExamplesInfo("OpenGVLab/Mono-InternVL-2B",
trust_remote_code=True),
"InternLM3ForCausalLM": _HfExamplesInfo("internlm/internlm3-8b-instruct",
trust_remote_code=True),
"HunYuanDenseV1ForCausalLM": _HfExamplesInfo(
"tencent/Hunyuan-7B-Instruct-0124",
trust_remote_code=True,
is_available_online=False,
),
"InternLMForCausalLM": _HfExamplesInfo(
"internlm/internlm-chat-7b", trust_remote_code=True
),
"InternLM2ForCausalLM": _HfExamplesInfo(
"internlm/internlm2-chat-7b", trust_remote_code=True
),
"InternLM2VEForCausalLM": _HfExamplesInfo(
"OpenGVLab/Mono-InternVL-2B", trust_remote_code=True
),
"InternLM3ForCausalLM": _HfExamplesInfo(
"internlm/internlm3-8b-instruct", trust_remote_code=True
),
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
"JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini",
min_transformers_version="4.55.3",
extras={
"tiny": "ai21labs/Jamba-tiny-dev",
"random": "ai21labs/Jamba-tiny-random", # noqa: E501
}),
"Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B",
min_transformers_version="4.54"),
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct",
extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501
"fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"}), # noqa: E501
"LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
is_available_online=False),
"Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
is_available_online=False),
"LongcatFlashForCausalLM": _HfExamplesInfo
("meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True),
"JambaForCausalLM": _HfExamplesInfo(
"ai21labs/AI21-Jamba-1.5-Mini",
min_transformers_version="4.55.3",
extras={
"tiny": "ai21labs/Jamba-tiny-dev",
"random": "ai21labs/Jamba-tiny-random",
},
),
"Lfm2ForCausalLM": _HfExamplesInfo(
"LiquidAI/LFM2-1.2B", min_transformers_version="4.54"
),
"LlamaForCausalLM": _HfExamplesInfo(
"meta-llama/Llama-3.2-1B-Instruct",
extras={
"guard": "meta-llama/Llama-Guard-3-1B",
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B",
"fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
},
),
"LLaMAForCausalLM": _HfExamplesInfo(
"decapoda-research/llama-7b-hf", is_available_online=False
),
"Llama4ForCausalLM": _HfExamplesInfo(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
is_available_online=False,
),
"LongcatFlashForCausalLM": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True
),
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
"Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1",
min_transformers_version="4.55.3",
extras={
"random": "yujiepan/mamba2-codestral-v0.1-tiny-random", # noqa: E501
}),
"FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501
"MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16",
trust_remote_code=True),
"MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
trust_remote_code=True),
"Mamba2ForCausalLM": _HfExamplesInfo(
"mistralai/Mamba-Codestral-7B-v0.1",
min_transformers_version="4.55.3",
extras={
"random": "yujiepan/mamba2-codestral-v0.1-tiny-random",
},
),
"FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"),
"MiniCPMForCausalLM": _HfExamplesInfo(
"openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True
),
"MiniCPM3ForCausalLM": _HfExamplesInfo(
"openbmb/MiniCPM3-4B", trust_remote_code=True
),
"MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf"),
"MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
trust_remote_code=True,
revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"), # noqa: E501
"MiniMaxM1ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-M1-40k",
trust_remote_code=True),
"MiniMaxText01ForCausalLM": _HfExamplesInfo(
"MiniMaxAI/MiniMax-Text-01",
trust_remote_code=True,
revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3",
),
"MiniMaxM1ForCausalLM": _HfExamplesInfo(
"MiniMaxAI/MiniMax-M1-40k", trust_remote_code=True
),
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", # noqa: E501
{"tiny": "TitanML/tiny-mixtral"}), # noqa: E501
"MixtralForCausalLM": _HfExamplesInfo(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
{"tiny": "TitanML/tiny-mixtral"},
),
"MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
"NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),
"NemotronHForCausalLM": _HfExamplesInfo("nvidia/Nemotron-H-8B-Base-8K",
trust_remote_code=True),
"NemotronHForCausalLM": _HfExamplesInfo(
"nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
),
"OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
"Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"),
"Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"),
"OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
"OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m",
{"1b": "facebook/opt-iml-max-1.3b"}),
"OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
trust_remote_code=True),
"OPTForCausalLM": _HfExamplesInfo(
"facebook/opt-125m", {"1b": "facebook/opt-iml-max-1.3b"}
),
"OrionForCausalLM": _HfExamplesInfo(
"OrionStarAI/Orion-14B-Chat", trust_remote_code=True
),
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
max_transformers_version="4.55.4",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
trust_remote_code=True),
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-0.5B-Instruct",
extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}), # noqa: E501
"PhiMoEForCausalLM": _HfExamplesInfo(
"microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True
),
"Plamo2ForCausalLM": _HfExamplesInfo(
"pfnet/plamo-2-1b",
max_transformers_version="4.55.4",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
trust_remote_code=True,
),
"QWenLMHeadModel": _HfExamplesInfo(
"Qwen/Qwen-7B-Chat",
max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
trust_remote_code=True,
),
"Qwen2ForCausalLM": _HfExamplesInfo(
"Qwen/Qwen2-0.5B-Instruct", extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}
),
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
extras={"tiny-random": "tiny-random/qwen3-next-moe"}, # noqa: E501
min_transformers_version="4.56.3"),
"Qwen3NextForCausalLM": _HfExamplesInfo(
"Qwen/Qwen3-Next-80B-A3B-Instruct",
extras={"tiny-random": "tiny-random/qwen3-next-moe"},
min_transformers_version="4.56.3",
),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
trust_remote_code=True,
is_available_online=False),
"SeedOssForCausalLM": _HfExamplesInfo(
"ByteDance-Seed/Seed-OSS-36B-Instruct",
trust_remote_code=True,
is_available_online=False,
),
"SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
"Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3",
trust_remote_code=True),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct",
trust_remote_code=True),
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
trust_remote_code=True),
"TeleFLMForCausalLM": _HfExamplesInfo("CofeAI/FLM-2-52B-Instruct-2407",
trust_remote_code=True),
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
tokenizer="meta-llama/Llama-2-7b",
trust_remote_code=True),
"Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3", trust_remote_code=True),
"SolarForCausalLM": _HfExamplesInfo(
"upstage/solar-pro-preview-instruct", trust_remote_code=True
),
"TeleChat2ForCausalLM": _HfExamplesInfo(
"Tele-AI/TeleChat2-3B", trust_remote_code=True
),
"TeleFLMForCausalLM": _HfExamplesInfo(
"CofeAI/FLM-2-52B-Instruct-2407", trust_remote_code=True
),
"XverseForCausalLM": _HfExamplesInfo(
"xverse/XVERSE-7B-Chat",
tokenizer="meta-llama/Llama-2-7b",
trust_remote_code=True,
),
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"),
"MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True),
"MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True),
"Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"),
}
_EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), # noqa: E501
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),
"GteNewModel": _HfExamplesInfo("Alibaba-NLP/gte-base-en-v1.5",
trust_remote_code=True,
hf_overrides={"architectures": ["GteNewModel"]}), # noqa: E501
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
trust_remote_code=True),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
"GteModel": _HfExamplesInfo(
"Snowflake/snowflake-arctic-embed-m-v2.0", trust_remote_code=True
),
"GteNewModel": _HfExamplesInfo(
"Alibaba-NLP/gte-base-en-v1.5",
trust_remote_code=True,
hf_overrides={"architectures": ["GteNewModel"]},
),
"InternLM2ForRewardModel": _HfExamplesInfo(
"internlm/internlm2-1_8b-reward", trust_remote_code=True
),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"),
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
trust_remote_code=True),
"NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
trust_remote_code=True), # noqa: E501
"ModernBertModel": _HfExamplesInfo(
"Alibaba-NLP/gte-modernbert-base", trust_remote_code=True
),
"NomicBertModel": _HfExamplesInfo(
"nomic-ai/nomic-embed-text-v2-moe", trust_remote_code=True
),
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B",
max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B",
max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # noqa: E501
"Qwen2ForRewardModel": _HfExamplesInfo(
"Qwen/Qwen2.5-Math-RM-72B",
max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
),
"Qwen2ForProcessRewardModel": _HfExamplesInfo(
"Qwen/Qwen2.5-Math-PRM-7B",
max_transformers_version="4.53",
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
),
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
# [Multimodal]
"CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
"PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model
# going OOM in CI
max_num_seqs=32,
),
"Terratorch": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model going OOM in CI
max_num_seqs=32,
),
"Phi3VForCausalLM": _HfExamplesInfo(
"TIGER-Lab/VLM2Vec-Full", trust_remote_code=True
),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"),
"PrithviGeoSpatialMAE": _HfExamplesInfo(
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model
# going OOM in CI
max_num_seqs=32,
),
"Terratorch": _HfExamplesInfo(
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model going OOM in CI
max_num_seqs=32,
),
}
_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
# [Decoder-only]
"GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501
"GPT2ForSequenceClassification": _HfExamplesInfo(
"nie3e/sentiment-polish-gpt2-small"
),
# [Cross-encoder]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
"BertForSequenceClassification": _HfExamplesInfo(
"cross-encoder/ms-marco-MiniLM-L-6-v2"
),
"BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"),
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
trust_remote_code=True,
hf_overrides={
"architectures": ["GteNewForSequenceClassification"]}),# noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
"GteNewForSequenceClassification": _HfExamplesInfo(
"Alibaba-NLP/gte-multilingual-reranker-base",
trust_remote_code=True,
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
),
"ModernBertForSequenceClassification": _HfExamplesInfo(
"Alibaba-NLP/gte-reranker-modernbert-base"
),
"RobertaForSequenceClassification": _HfExamplesInfo(
"cross-encoder/quora-roberta-base"
),
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"),
}
_AUTOMATIC_CONVERTED_MODELS = {
# Use as_seq_cls_model for automatic conversion
"GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501
hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
"classifier_from_token": ["Yes"], # noqa: E501
"method": "no_post_processing"}), # noqa: E501
"LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
"Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
"GemmaForSequenceClassification": _HfExamplesInfo(
"BAAI/bge-reranker-v2-gemma",
hf_overrides={
"architectures": ["GemmaForSequenceClassification"],
"classifier_from_token": ["Yes"],
"method": "no_post_processing",
},
),
"LlamaForSequenceClassification": _HfExamplesInfo(
"Skywork/Skywork-Reward-V2-Llama-3.2-1B"
),
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"),
"Qwen3ForSequenceClassification": _HfExamplesInfo(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
),
}
_MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"Cohere2VisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/command-a-vision-07-2025"), # noqa: E501
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible.", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"DotsOCRForCausalLM": _HfExamplesInfo("rednote-hilab/dots.ocr",
trust_remote_code=True),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"),
"Blip2ForConditionalGeneration": _HfExamplesInfo(
"Salesforce/blip2-opt-2.7b",
extras={"6b": "Salesforce/blip2-opt-6.7b"},
),
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"),
"Cohere2VisionForConditionalGeneration": _HfExamplesInfo(
"CohereLabs/command-a-vision-07-2025"
),
"DeepseekVLV2ForCausalLM": _HfExamplesInfo(
"deepseek-ai/deepseek-vl2-tiny",
extras={"fork": "Isotr0py/deepseek-vl2-tiny"},
max_transformers_version="4.48",
transformers_version_reason="HF model is not compatible.",
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
),
"DotsOCRForCausalLM": _HfExamplesInfo(
"rednote-hilab/dots.ocr", trust_remote_code=True
),
"Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501
trust_remote_code=True),
"Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo(
"baidu/ERNIE-4.5-VL-28B-A3B-PT",
trust_remote_code=True,
),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
min_transformers_version="4.53"),
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501
"GLM4VForCausalLM": _HfExamplesInfo("zai-org/glm-4v-9b",
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), # noqa: E501
"Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V",
min_transformers_version="4.56"), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
trust_remote_code=True,
extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible."), # noqa: E501
"HCXVisionForCausalLM": _HfExamplesInfo("naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", # noqa: E501
trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}, # noqa: E501
min_transformers_version="4.56",
transformers_version_reason="HF model broken in 4.55"), # noqa: E501
"InternS1ForConditionalGeneration": _HfExamplesInfo("internlm/Intern-S1",
trust_remote_code=True), # noqa: E501
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
extras={"2B": "OpenGVLab/InternVL2-2B",
"3.0": "OpenGVLab/InternVL3-1B", # noqa: E501
"3.5-qwen3": "OpenGVLab/InternVL3_5-1B", # noqa: E501
"3.5-qwen3moe": "OpenGVLab/InternVL3_5-30B-A3B", # noqa: E501
"3.5-gptoss": "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview"}, # noqa: E501
trust_remote_code=True),
"InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), # noqa: E501
"KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501
trust_remote_code=True),
"KeyeVL1_5ForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-1_5-8B", # noqa: E501
trust_remote_code=True),
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
trust_remote_code=True),
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
max_model_len=10240,
extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"}, # noqa: E501
),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501
"mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}), # noqa: E501
"LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
"MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3", # noqa: E501
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible.", # noqa: E501
hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501
"MiDashengLMModel": _HfExamplesInfo("mispeech/midashenglm-7b",
trust_remote_code=True),
"MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6",
trust_remote_code=True),
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4", "4.5": "openbmb/MiniCPM-V-4_5"}, # noqa: E501
trust_remote_code=True),
"MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501
trust_remote_code=True,
v0_only=True),
"Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501
extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
max_transformers_version="4.48",
transformers_version_reason="Incorrectly-detected `tensorflow` import.", # noqa: E501
extras={"olmo": "allenai/Molmo-7B-O-0924"}, # noqa: E501
trust_remote_code=True),
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
trust_remote_code=True),
"Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501
trust_remote_code=True),
"NemotronH_Nano_VL_V2": _HfExamplesInfo("nano_vl_dummy",
is_available_online=False,
trust_remote_code=True),
"Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
max_transformers_version="4.53",
transformers_version_reason="HF model is not compatible", # noqa: E501
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
"Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B",
trust_remote_code=True),
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
trust_remote_code=True,
max_transformers_version="4.48",
transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501
extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True),
"Phi4MultimodalForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", # noqa: E501
revision="refs/pr/70"),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
tokenizer_mode="mistral"),
"QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
extras={"chat": "Qwen/Qwen-VL-Chat"}, # noqa: E501
trust_remote_code=True,
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}), # noqa: E501
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
max_model_len=4096),
"Gemma3nForConditionalGeneration": _HfExamplesInfo(
"google/gemma-3n-E2B-it",
min_transformers_version="4.53",
),
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo(
"ibm-granite/granite-speech-3.3-2b"
),
"GLM4VForCausalLM": _HfExamplesInfo(
"zai-org/glm-4v-9b",
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]},
),
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"),
"Glm4vMoeForConditionalGeneration": _HfExamplesInfo(
"zai-org/GLM-4.5V", min_transformers_version="4.56"
),
"H2OVLChatModel": _HfExamplesInfo(
"h2oai/h2ovl-mississippi-800m",
trust_remote_code=True,
extras={"2b": "h2oai/h2ovl-mississippi-2b"},
max_transformers_version="4.48",
transformers_version_reason="HF model is not compatible.",
),
"HCXVisionForCausalLM": _HfExamplesInfo(
"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
trust_remote_code=True,
),
"Idefics3ForConditionalGeneration": _HfExamplesInfo(
"HuggingFaceM4/Idefics3-8B-Llama3",
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},
min_transformers_version="4.56",
transformers_version_reason="HF model broken in 4.55",
),
"InternS1ForConditionalGeneration": _HfExamplesInfo(
"internlm/Intern-S1", trust_remote_code=True
),
"InternVLChatModel": _HfExamplesInfo(
"OpenGVLab/InternVL2-1B",
extras={
"2B": "OpenGVLab/InternVL2-2B",
"3.0": "OpenGVLab/InternVL3-1B",
"3.5-qwen3": "OpenGVLab/InternVL3_5-1B",
"3.5-qwen3moe": "OpenGVLab/InternVL3_5-30B-A3B",
"3.5-gptoss": "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview",
},
trust_remote_code=True,
),
"InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
"KeyeForConditionalGeneration": _HfExamplesInfo(
"Kwai-Keye/Keye-VL-8B-Preview",
trust_remote_code=True,
),
"KeyeVL1_5ForConditionalGeneration": _HfExamplesInfo(
"Kwai-Keye/Keye-VL-1_5-8B",
trust_remote_code=True,
),
"KimiVLForConditionalGeneration": _HfExamplesInfo(
"moonshotai/Kimi-VL-A3B-Instruct",
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},
trust_remote_code=True,
),
"Llama4ForConditionalGeneration": _HfExamplesInfo(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
max_model_len=10240,
extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"},
),
"LlavaForConditionalGeneration": _HfExamplesInfo(
"llava-hf/llava-1.5-7b-hf",
extras={
"mistral": "mistral-community/pixtral-12b",
"mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic",
},
),
"LlavaNextForConditionalGeneration": _HfExamplesInfo(
"llava-hf/llava-v1.6-mistral-7b-hf"
),
"LlavaNextVideoForConditionalGeneration": _HfExamplesInfo(
"llava-hf/LLaVA-NeXT-Video-7B-hf"
),
"LlavaOnevisionForConditionalGeneration": _HfExamplesInfo(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
),
"MantisForConditionalGeneration": _HfExamplesInfo(
"TIGER-Lab/Mantis-8B-siglip-llama3",
max_transformers_version="4.48",
transformers_version_reason="HF model is not compatible.",
hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
),
"MiDashengLMModel": _HfExamplesInfo(
"mispeech/midashenglm-7b", trust_remote_code=True
),
"MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", trust_remote_code=True),
"MiniCPMV": _HfExamplesInfo(
"openbmb/MiniCPM-Llama3-V-2_5",
extras={
"2.6": "openbmb/MiniCPM-V-2_6",
"4.0": "openbmb/MiniCPM-V-4",
"4.5": "openbmb/MiniCPM-V-4_5",
},
trust_remote_code=True,
),
"MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo(
"MiniMaxAI/MiniMax-VL-01",
trust_remote_code=True,
v0_only=True,
),
"Mistral3ForConditionalGeneration": _HfExamplesInfo(
"mistralai/Mistral-Small-3.1-24B-Instruct-2503",
extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"},
),
"MolmoForCausalLM": _HfExamplesInfo(
"allenai/Molmo-7B-D-0924",
max_transformers_version="4.48",
transformers_version_reason="Incorrectly-detected `tensorflow` import.",
extras={"olmo": "allenai/Molmo-7B-O-0924"},
trust_remote_code=True,
),
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", trust_remote_code=True),
"Llama_Nemotron_Nano_VL": _HfExamplesInfo(
"nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1",
trust_remote_code=True,
),
"NemotronH_Nano_VL_V2": _HfExamplesInfo(
"nano_vl_dummy", is_available_online=False, trust_remote_code=True
),
"Ovis": _HfExamplesInfo(
"AIDC-AI/Ovis2-1B",
trust_remote_code=True,
max_transformers_version="4.53",
transformers_version_reason="HF model is not compatible",
extras={
"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B",
},
),
"Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B", trust_remote_code=True),
"PaliGemmaForConditionalGeneration": _HfExamplesInfo(
"google/paligemma-3b-mix-224",
extras={"v2": "google/paligemma2-3b-ft-docci-448"},
),
"Phi3VForCausalLM": _HfExamplesInfo(
"microsoft/Phi-3-vision-128k-instruct",
trust_remote_code=True,
max_transformers_version="4.48",
transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501
extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"},
),
"Phi4MMForCausalLM": _HfExamplesInfo(
"microsoft/Phi-4-multimodal-instruct", trust_remote_code=True
),
"Phi4MultimodalForCausalLM": _HfExamplesInfo(
"microsoft/Phi-4-multimodal-instruct",
revision="refs/pr/70",
),
"PixtralForConditionalGeneration": _HfExamplesInfo(
"mistralai/Pixtral-12B-2409",
tokenizer_mode="mistral",
),
"QwenVLForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen-VL",
extras={"chat": "Qwen/Qwen-VL-Chat"},
trust_remote_code=True,
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
),
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen2-Audio-7B-Instruct"
),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"),
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen2.5-VL-3B-Instruct",
max_model_len=4096,
),
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"),
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501
"Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct", # noqa: E501
max_model_len=4096,
min_transformers_version="4.57",
is_available_online=False),
"Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501
max_model_len=4096,
min_transformers_version="4.57",
is_available_online=False),
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B",
trust_remote_code=True),
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
trust_remote_code=True),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct", # noqa: E501
min_transformers_version="4.56",
transformers_version_reason="HF model broken in 4.55"), # noqa: E501
"Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3",
trust_remote_code=True),
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
trust_remote_code=True),
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"), # noqa: E501
"Tarsier2ForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier2-Recap-7b", # noqa: E501
hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]}), # noqa: E501
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"),
"Qwen3VLForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3-VL-4B-Instruct",
max_model_len=4096,
min_transformers_version="4.57",
is_available_online=False,
),
"Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo(
"Qwen/Qwen3-VL-30B-A3B-Instruct",
max_model_len=4096,
min_transformers_version="4.57",
is_available_online=False,
),
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True),
"SkyworkR1VChatModel": _HfExamplesInfo(
"Skywork/Skywork-R1V-38B", trust_remote_code=True
),
"SmolVLMForConditionalGeneration": _HfExamplesInfo(
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
min_transformers_version="4.56",
transformers_version_reason="HF model broken in 4.55",
),
"Step3VLForConditionalGeneration": _HfExamplesInfo(
"stepfun-ai/step3", trust_remote_code=True
),
"UltravoxModel": _HfExamplesInfo(
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
trust_remote_code=True,
),
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"),
"Tarsier2ForConditionalGeneration": _HfExamplesInfo(
"omni-research/Tarsier2-Recap-7b",
hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]},
),
"VoxtralForConditionalGeneration": _HfExamplesInfo(
"mistralai/Voxtral-Mini-3B-2507",
min_transformers_version="4.54",
......@@ -600,80 +795,120 @@ _MULTIMODAL_EXAMPLE_MODELS = {
is_available_online=False,
),
# [Encoder-decoder]
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"),
# [Cross-encoder]
"JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501
"JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"),
}
_SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"MedusaModel": _HfExamplesInfo("JackFram/llama-68m",
speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501
"MedusaModel": _HfExamplesInfo(
"JackFram/llama-68m", speculative_model="abhigoyal/vllm-medusa-llama-68m-random"
),
# Temporarily disabled.
# TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
# speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501
"DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501
trust_remote_code=True),
"EagleDeepSeekMTPModel": _HfExamplesInfo("eagle618/deepseek-v3-random",
speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501
trust_remote_code=True),
"EagleLlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B-Instruct", # noqa: E501
trust_remote_code=True,
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501
"Eagle3LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.1-8B-Instruct", # noqa: E501
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501
tokenizer="meta-llama/Llama-3.1-8B-Instruct",
use_original_num_layers=True,
max_model_len=10240),
"LlamaForCausalLMEagle3": _HfExamplesInfo("Qwen/Qwen3-8B", # noqa: E501
trust_remote_code=True,
speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501
tokenizer="Qwen/Qwen3-8B",
use_original_num_layers=True),
# "MLPSpeculatorPreTrainedModel": _HfExamplesInfo(
# "JackFram/llama-160m",
# speculative_model="ibm-ai-platform/llama-160m-accelerator"
# ),
"DeepSeekMTPModel": _HfExamplesInfo(
"luccafong/deepseek_mtp_main_random",
speculative_model="luccafong/deepseek_mtp_draft_random",
trust_remote_code=True,
),
"EagleDeepSeekMTPModel": _HfExamplesInfo(
"eagle618/deepseek-v3-random",
speculative_model="eagle618/eagle-deepseek-v3-random",
trust_remote_code=True,
),
"EagleLlamaForCausalLM": _HfExamplesInfo(
"meta-llama/Meta-Llama-3-8B-Instruct",
trust_remote_code=True,
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
),
"Eagle3LlamaForCausalLM": _HfExamplesInfo(
"meta-llama/Llama-3.1-8B-Instruct",
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="meta-llama/Llama-3.1-8B-Instruct",
use_original_num_layers=True,
max_model_len=10240,
),
"LlamaForCausalLMEagle3": _HfExamplesInfo(
"Qwen/Qwen3-8B",
trust_remote_code=True,
speculative_model="AngelSlim/Qwen3-8B_eagle3",
tokenizer="Qwen/Qwen3-8B",
use_original_num_layers=True,
),
"EagleLlama4ForCausalLM": _HfExamplesInfo(
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
trust_remote_code=True,
speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
"EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
trust_remote_code=True,
is_available_online=False,
speculative_model="openbmb/MiniCPM-2B-sft-bf16",
tokenizer="openbmb/MiniCPM-2B-sft-bf16"),
"ErnieMTPModel": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",
trust_remote_code=True,
speculative_model="baidu/ERNIE-4.5-21B-A3B-PT"),
"Glm4MoeMTPModel": _HfExamplesInfo("zai-org/GLM-4.5",
speculative_model="zai-org/GLM-4.5",
min_transformers_version="4.56",
is_available_online=False),
tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct",
),
"EagleMiniCPMForCausalLM": _HfExamplesInfo(
"openbmb/MiniCPM-1B-sft-bf16",
trust_remote_code=True,
is_available_online=False,
speculative_model="openbmb/MiniCPM-2B-sft-bf16",
tokenizer="openbmb/MiniCPM-2B-sft-bf16",
),
"ErnieMTPModel": _HfExamplesInfo(
"baidu/ERNIE-4.5-21B-A3B-PT",
trust_remote_code=True,
speculative_model="baidu/ERNIE-4.5-21B-A3B-PT",
),
"Glm4MoeMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.5",
speculative_model="zai-org/GLM-4.5",
min_transformers_version="4.56",
is_available_online=False,
),
"LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat",
trust_remote_code=True,
speculative_model="meituan-longcat/LongCat-Flash-Chat"),
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
speculative_model="meituan-longcat/LongCat-Flash-Chat",
),
"MiMoMTPModel": _HfExamplesInfo(
"XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL",
),
"Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo(
"Qwen/Qwen2.5-VL-7B-Instruct",
speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"),
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.3"),
speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl",
),
"Qwen3NextMTP": _HfExamplesInfo(
"Qwen/Qwen3-Next-80B-A3B-Instruct", min_transformers_version="4.56.3"
),
}
_TRANSFORMERS_BACKEND_MODELS = {
"TransformersEmbeddingModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersEmbeddingModel": _HfExamplesInfo(
"BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0"
),
"TransformersForSequenceClassification": _HfExamplesInfo(
"papluca/xlm-roberta-base-language-detection",
min_transformers_version="4.57.0.dev0",
),
"TransformersForCausalLM": _HfExamplesInfo(
"hmellor/Ilama-3.2-1B", trust_remote_code=True
),
"TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"TransformersMoEForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersMoEForMultimodalLM": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersMoEEmbeddingModel": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersMoEForSequenceClassification": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersMoEForCausalLM": _HfExamplesInfo(
"allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"
),
"TransformersMoEForMultimodalLM": _HfExamplesInfo(
"Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"
),
"TransformersMoEEmbeddingModel": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"
),
"TransformersMoEForSequenceClassification": _HfExamplesInfo(
"Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"
),
}
_EXAMPLE_MODELS = {
......@@ -699,8 +934,9 @@ class HfExampleModels:
try:
return self.hf_models[model_arch]
except KeyError:
raise ValueError(f"No example model defined for {model_arch}; "
f"please update this file.") from None
raise ValueError(
f"No example model defined for {model_arch}; please update this file."
) from None
def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
for info in self.hf_models.values():
......@@ -712,8 +948,9 @@ class HfExampleModels:
if any(extra == model_id for extra in info.extras.values()):
return info
raise ValueError(f"No example model defined for {model_id}; "
f"please update this file.")
raise ValueError(
f"No example model defined for {model_id}; please update this file."
)
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
......
......@@ -71,25 +71,27 @@ def _dummy_items(
)
# yapf: disable
@pytest.mark.parametrize(
("item", "expected_size"),
[
(_dummy_item("a", {"a1": 100}), 100),
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
(_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
(_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}).get_data(), 460), # noqa: E501
(
_dummy_items(
{"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}
).get_data(),
460,
), # noqa: E501
],
)
# yapf: enable
def test_cache_item_size(item, expected_size):
cache = MultiModalCache.get_lru_cache(2048, type(item))
cache[""] = item
assert cache.currsize == expected_size
prompt_update = PromptInsertion("dummy", "target", "insertion") \
.resolve(0)
prompt_update = PromptInsertion("dummy", "target", "insertion").resolve(0)
cache[""] = MultiModalProcessorCacheItem(item, [prompt_update])
assert cache.currsize == expected_size
......@@ -106,9 +108,9 @@ def _create_vllm_config(
return VllmConfig(
model_config=ModelConfig(
model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
mm_processor_cache_gb=mm_processor_cache_gb),
parallel_config=ParallelConfig(
data_parallel_size=1 if enable_ipc else 2),
mm_processor_cache_gb=mm_processor_cache_gb,
),
parallel_config=ParallelConfig(data_parallel_size=1 if enable_ipc else 2),
)
......@@ -124,11 +126,9 @@ def _compare_caches(
seed: int = 0,
):
cache_0_p0 = processor_cache_from_config(config_0, MULTIMODAL_REGISTRY)
cache_0_p1 = engine_receiver_cache_from_config(config_0,
MULTIMODAL_REGISTRY)
cache_0_p1 = engine_receiver_cache_from_config(config_0, MULTIMODAL_REGISTRY)
cache_1_p0 = processor_cache_from_config(config_1, MULTIMODAL_REGISTRY)
cache_1_p1 = engine_receiver_cache_from_config(config_1,
MULTIMODAL_REGISTRY)
cache_1_p1 = engine_receiver_cache_from_config(config_1, MULTIMODAL_REGISTRY)
cache_size_gb = max(
config_0.model_config.multimodal_config.mm_processor_cache_gb,
......@@ -142,8 +142,7 @@ def _compare_caches(
for _ in range(int(item_capacity / hit_rate))
]
all_hashes = [
MultiModalHasher.hash_kwargs(item=item.get_data())
for item in all_items
MultiModalHasher.hash_kwargs(item=item.get_data()) for item in all_items
]
# Should not be used since there is nothing to convert to text
......@@ -162,7 +161,8 @@ def _compare_caches(
for _ in range(is_cached_calls_per_iter):
cache_0_p0.is_cached(selected_hashes)
cache_0_p0_out = [
item for item, _ in cache_0_p0.get_and_update(
item
for item, _ in cache_0_p0.get_and_update(
[(item, prompt_update.content) for item in selected_items],
selected_hashes,
)
......@@ -174,7 +174,8 @@ def _compare_caches(
for _ in range(is_cached_calls_per_iter):
cache_1_p0.is_cached(selected_hashes)
cache_1_p0_out = [
item for item, _ in cache_1_p0.get_and_update(
item
for item, _ in cache_1_p0.get_and_update(
[(item, prompt_update.content) for item in selected_items],
selected_hashes,
)
......@@ -183,14 +184,12 @@ def _compare_caches(
if cache_0_p1 is None:
cache_0_p1_out = cache_0_p0_out
else:
cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out,
selected_hashes)
cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out, selected_hashes)
if cache_1_p1 is None:
cache_1_p1_out = cache_1_p0_out
else:
cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out,
selected_hashes)
cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out, selected_hashes)
assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}"
......
......@@ -9,9 +9,6 @@ import pytest
from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (
InputProcessingContext,
PlaceholderFeaturesInfo,
......@@ -24,8 +21,6 @@ from vllm.multimodal.processing import (
iter_token_matches,
replace_token_matches,
)
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import AnyTokenizer
......@@ -34,7 +29,6 @@ from .utils import random_image
pytestmark = pytest.mark.cpu_test
# yapf: disable
@pytest.mark.parametrize(
("token_ids", "match_ids", "expected"),
[
......@@ -44,34 +38,34 @@ pytestmark = pytest.mark.cpu_test
[32000, 32000, 32000],
[32000],
[
{ "start_idx": 0, "end_idx": 1 },
{ "start_idx": 1, "end_idx": 2 },
{ "start_idx": 2, "end_idx": 3 },
{"start_idx": 0, "end_idx": 1},
{"start_idx": 1, "end_idx": 2},
{"start_idx": 2, "end_idx": 3},
],
),
(
[32000, 32000, 32000],
[32000, 32000],
[{ "start_idx": 0, "end_idx": 2 }],
[{"start_idx": 0, "end_idx": 2}],
),
(
[32000, 32000, 32000],
[32000, 32000, 32000],
[{ "start_idx": 0, "end_idx": 3 }],
[{"start_idx": 0, "end_idx": 3}],
),
(
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
[28747, 32000],
[
{ "start_idx": 1, "end_idx": 3 },
{ "start_idx": 6, "end_idx": 8 },
{"start_idx": 1, "end_idx": 3},
{"start_idx": 6, "end_idx": 8},
],
),
(
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
[28747, 32000, 32000, 32000],
[
{ "start_idx": 1, "end_idx": 5 },
{"start_idx": 1, "end_idx": 5},
],
),
(
......@@ -82,14 +76,13 @@ pytestmark = pytest.mark.cpu_test
],
)
@pytest.mark.parametrize("start_idx", [0, 4, 8])
# yapf: enable
def test_iter_token_matches(token_ids, match_ids, expected, start_idx):
result = list(iter_token_matches(token_ids, match_ids,
start_idx=start_idx))
result = list(iter_token_matches(token_ids, match_ids, start_idx=start_idx))
# Manually constructed results
assert [item._asdict() for item in result
] == [item for item in expected if item["start_idx"] >= start_idx]
assert [item._asdict() for item in result] == [
item for item in expected if item["start_idx"] >= start_idx
]
# Invariants
match_lens = [end - start for start, end in result]
......@@ -97,7 +90,6 @@ def test_iter_token_matches(token_ids, match_ids, expected, start_idx):
assert all(match_len == len(match_ids) for match_len in match_lens)
# yapf: disable
@pytest.mark.parametrize(
("token_ids", "match_ids", "new_ids", "expected"),
[
......@@ -141,7 +133,6 @@ def test_iter_token_matches(token_ids, match_ids, expected, start_idx):
),
],
)
# yapf: enable
def test_replace_token_matches(token_ids, match_ids, new_ids, expected):
result = replace_token_matches(token_ids, match_ids, new_ids)
......@@ -149,7 +140,6 @@ def test_replace_token_matches(token_ids, match_ids, new_ids, expected):
assert result == expected
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "expected_by_key"),
[
......@@ -166,11 +156,11 @@ def test_replace_token_matches(token_ids, match_ids, new_ids, expected):
"pattern_1": [],
"pattern_2": [],
"pattern_3": [
{ "start_idx": 0, "end_idx": 0 },
{"start_idx": 0, "end_idx": 0},
],
"pattern_4": [],
"pattern_5": [
{ "start_idx": 0, "end_idx": 0 },
{"start_idx": 0, "end_idx": 0},
],
},
),
......@@ -186,26 +176,26 @@ def test_replace_token_matches(token_ids, match_ids, new_ids, expected):
},
{
"pattern_1": [
{ "start_idx": 0, "end_idx": 1 },
{ "start_idx": 1, "end_idx": 2 },
{ "start_idx": 2, "end_idx": 3 },
{ "start_idx": 3, "end_idx": 4 },
{"start_idx": 0, "end_idx": 1},
{"start_idx": 1, "end_idx": 2},
{"start_idx": 2, "end_idx": 3},
{"start_idx": 3, "end_idx": 4},
],
"pattern_2": [
{ "start_idx": 0, "end_idx": 2 },
{ "start_idx": 2, "end_idx": 4 },
{"start_idx": 0, "end_idx": 2},
{"start_idx": 2, "end_idx": 4},
],
"pattern_3": [
{ "start_idx": 0, "end_idx": 3 },
{"start_idx": 0, "end_idx": 3},
],
"pattern_4": [
{ "start_idx": 0, "end_idx": 0 },
{"start_idx": 0, "end_idx": 0},
],
"pattern_5": [
{ "start_idx": 1, "end_idx": 1 },
{"start_idx": 1, "end_idx": 1},
],
"pattern_6": [
{ "start_idx": 4, "end_idx": 4 },
{"start_idx": 4, "end_idx": 4},
],
},
),
......@@ -221,26 +211,25 @@ def test_replace_token_matches(token_ids, match_ids, new_ids, expected):
},
{
"pattern_1": [
{ "start_idx": 1, "end_idx": 3 },
{ "start_idx": 6, "end_idx": 8 },
{"start_idx": 1, "end_idx": 3},
{"start_idx": 6, "end_idx": 8},
],
"pattern_2": [
{ "start_idx": 1, "end_idx": 5 },
{"start_idx": 1, "end_idx": 5},
],
"pattern_3": [],
"pattern_4": [
{ "start_idx": 0, "end_idx": 0 },
{"start_idx": 0, "end_idx": 0},
],
"pattern_5": [],
"pattern_6": [
{ "start_idx": 10, "end_idx": 10 },
{"start_idx": 10, "end_idx": 10},
],
},
),
],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
def test_find_token_matches(
prompt,
target_by_key,
......@@ -272,7 +261,6 @@ def test_find_token_matches(
} == expected_by_key
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "expected_by_key"),
[
......@@ -288,16 +276,16 @@ def test_find_token_matches(
"pattern_5": PromptIndexTargets.end(),
},
{
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
"pattern_1": [{"start_idx": 0, "end_idx": 0}],
"pattern_2": [],
"pattern_3": [
{ "start_idx": 0, "end_idx": 0 },
{"start_idx": 0, "end_idx": 0},
],
"pattern_4": [],
"pattern_5": [
{ "start_idx": 0, "end_idx": 0 },
{"start_idx": 0, "end_idx": 0},
],
}
},
),
(
"<image><image><image><image>",
......@@ -311,26 +299,26 @@ def test_find_token_matches(
},
{
"pattern_1": [
{ "start_idx": 0, "end_idx": 7 },
{ "start_idx": 7, "end_idx": 14 },
{ "start_idx": 14, "end_idx": 21 },
{ "start_idx": 21, "end_idx": 28 },
{"start_idx": 0, "end_idx": 7},
{"start_idx": 7, "end_idx": 14},
{"start_idx": 14, "end_idx": 21},
{"start_idx": 21, "end_idx": 28},
],
"pattern_2": [
{ "start_idx": 0, "end_idx": 14 },
{ "start_idx": 14, "end_idx": 28 },
{"start_idx": 0, "end_idx": 14},
{"start_idx": 14, "end_idx": 28},
],
"pattern_3": [
{ "start_idx": 0, "end_idx": 21 },
{"start_idx": 0, "end_idx": 21},
],
"pattern_4": [
{ "start_idx": 0, "end_idx": 0 },
{"start_idx": 0, "end_idx": 0},
],
"pattern_5": [
{ "start_idx": 7, "end_idx": 7 },
{"start_idx": 7, "end_idx": 7},
],
"pattern_6": [
{ "start_idx": 28, "end_idx": 28 },
{"start_idx": 28, "end_idx": 28},
],
},
),
......@@ -346,21 +334,21 @@ def test_find_token_matches(
},
{
"pattern_1": [
{ "start_idx": 0, "end_idx": 13 },
{ "start_idx": 27, "end_idx": 40 },
{"start_idx": 0, "end_idx": 13},
{"start_idx": 27, "end_idx": 40},
],
"pattern_2": [
{ "start_idx": 0, "end_idx": 27 },
{"start_idx": 0, "end_idx": 27},
],
"pattern_3": [],
"pattern_4": [
{ "start_idx": 0, "end_idx": 0 },
{"start_idx": 0, "end_idx": 0},
],
"pattern_5": [
{ "start_idx": 13, "end_idx": 13 },
{"start_idx": 13, "end_idx": 13},
],
"pattern_6": [
{ "start_idx": 48, "end_idx": 48 },
{"start_idx": 48, "end_idx": 48},
],
},
),
......@@ -374,22 +362,21 @@ def test_find_token_matches(
},
{
"pattern_1": [
{ "start_idx": 0, "end_idx": 9 },
{ "start_idx": 16, "end_idx": 25 },
{"start_idx": 0, "end_idx": 9},
{"start_idx": 16, "end_idx": 25},
],
"pattern_2": [
{ "start_idx": 0, "end_idx": 16 },
{ "start_idx": 16, "end_idx": 32 },
{"start_idx": 0, "end_idx": 16},
{"start_idx": 16, "end_idx": 32},
],
"pattern_3": [
{ "start_idx": 0, "end_idx": 25 },
{"start_idx": 0, "end_idx": 25},
],
},
),
],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
def test_find_text_matches(
prompt,
target_by_key,
......@@ -421,7 +408,6 @@ def test_find_text_matches(
} == expected_by_key
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[
......@@ -549,9 +535,8 @@ def test_find_text_matches(
},
},
),
]
],
)
# yapf: enable
def test_find_update_text(
prompt,
target_by_key,
......@@ -562,13 +547,15 @@ def test_find_update_text(
mock_tokenizer = cast(AnyTokenizer, object())
for (
update_type,
expected_by_mm_count,
update_type,
expected_by_mm_count,
) in expected_by_update_type_mm_count.items():
for mm_count, expected in expected_by_mm_count.items():
mm_prompt_updates = {
key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
for i in range(mm_count)]
key: [
[update_type(key, target, repl_by_key[key]).resolve(i)]
for i in range(mm_count)
]
for key, target in target_by_key.items()
}
......@@ -589,7 +576,6 @@ def test_find_update_text(
assert new_prompt == expected
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[
......@@ -615,8 +601,43 @@ def test_find_update_text(
{
PromptInsertion: {
0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
1: [1, 9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550], # noqa: E501
2: [1, 9833, 28747, 32000, 32000, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550, 1550, 918, 1550], # noqa: E501
1: [
1,
9833,
28747,
32000,
32000,
32000,
9833,
28747,
32000,
32000,
918,
1550,
918,
1550,
], # noqa: E501
2: [
1,
9833,
28747,
32000,
32000,
32000,
32000,
32000,
9833,
28747,
32000,
32000,
918,
1550,
918,
1550,
1550,
918,
1550,
], # noqa: E501
},
PromptReplacement: {
0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
......@@ -719,9 +740,8 @@ def test_find_update_text(
},
},
),
]
],
)
# yapf: enable
def test_find_update_tokens(
prompt,
target_by_key,
......@@ -732,13 +752,15 @@ def test_find_update_tokens(
mock_tokenizer = cast(AnyTokenizer, object())
for (
update_type,
expected_by_mm_count,
update_type,
expected_by_mm_count,
) in expected_by_update_type_mm_count.items():
for mm_count, expected in expected_by_mm_count.items():
mm_prompt_updates = {
key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
for i in range(mm_count)]
key: [
[update_type(key, target, repl_by_key[key]).resolve(i)]
for i in range(mm_count)
]
for key, target in target_by_key.items()
}
......@@ -759,7 +781,6 @@ def test_find_update_tokens(
assert new_prompt == expected
# yapf: disable
@pytest.mark.parametrize(
"repl_by_key",
[
......@@ -796,8 +817,7 @@ def test_find_update_tokens(
is_embed=None,
),
],
}
},
),
(
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
......@@ -828,7 +848,7 @@ def test_find_update_tokens(
),
],
# No match for pattern_4 as it has lower priority than pattern_1
}
},
),
(
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
......@@ -867,12 +887,11 @@ def test_find_update_tokens(
is_embed=None,
),
],
}
},
),
]
],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
def test_find_mm_placeholders(
repl_by_key,
prompt,
......@@ -899,8 +918,15 @@ def test_find_mm_placeholders(
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
("limit", "num_supported", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
(2, 1, False), (2, 2, True)],
[
(0, 0, True),
(0, 1, True),
(1, 0, False),
(1, 1, True),
(1, 2, True),
(2, 1, False),
(2, 2, True),
],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
limit_mm_per_prompt = {"image": limit}
......@@ -930,8 +956,15 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
("num_images", "limit", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
(2, 1, False), (2, 2, True)],
[
(0, 0, True),
(0, 1, True),
(1, 0, False),
(1, 1, True),
(1, 2, True),
(2, 1, False),
(2, 2, True),
],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
limit_mm_per_prompt = {"image": limit}
......@@ -966,7 +999,6 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
class DummyProcessor:
def __init__(self, a: int = 0, b: int = 0) -> None:
super().__init__()
......@@ -982,7 +1014,6 @@ class DummyProcessor:
return dict(a=a, c=c)
# yapf: disable
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # Dummy
@pytest.mark.parametrize(
("config_kwargs", "inference_kwargs", "expected_kwargs"),
......@@ -996,7 +1027,6 @@ class DummyProcessor:
({"b": 1, "c": 1}, {}, {"a": 0, "b": 1}),
],
)
# yapf: enable
def test_hf_processor_init_kwargs(
model_id,
config_kwargs,
......@@ -1020,7 +1050,6 @@ def test_hf_processor_init_kwargs(
assert getattr(processor, k) == v
# yapf: disable
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # Dummy
@pytest.mark.parametrize(
("config_kwargs", "inference_kwargs", "expected_kwargs"),
......@@ -1034,7 +1063,6 @@ def test_hf_processor_init_kwargs(
({"b": 1, "c": 1}, {}, {"a": 0, "c": 1}),
],
)
# yapf: enable
def test_hf_processor_call_kwargs(
model_id,
config_kwargs,
......
......@@ -233,7 +233,6 @@ async def test_fetch_video_http_with_dynamic_loader(
assert metadata_sync["video_backend"] == "opencv_dynamic"
# yapf: disable
@pytest.mark.parametrize(
"case",
[
......@@ -264,7 +263,6 @@ async def test_fetch_video_http_with_dynamic_loader(
("image", 0),
],
),
# Two modalities
## Internally sorted
dict(
......@@ -276,7 +274,7 @@ async def test_fetch_video_http_with_dynamic_loader(
"audio": [
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=3),
]
],
},
expected_modality_idxs=[
("audio", 0),
......@@ -295,7 +293,7 @@ async def test_fetch_video_http_with_dynamic_loader(
"audio": [
PlaceholderRange(offset=5, length=2),
PlaceholderRange(offset=11, length=4),
]
],
},
expected_modality_idxs=[
("image", 0),
......@@ -314,7 +312,7 @@ async def test_fetch_video_http_with_dynamic_loader(
"audio": [
PlaceholderRange(offset=11, length=4),
PlaceholderRange(offset=5, length=2),
]
],
},
expected_modality_idxs=[
("image", 1),
......@@ -323,7 +321,6 @@ async def test_fetch_video_http_with_dynamic_loader(
("audio", 0),
],
),
# Three modalities
## Internally sorted
dict(
......@@ -339,7 +336,7 @@ async def test_fetch_video_http_with_dynamic_loader(
PlaceholderRange(offset=3, length=4),
PlaceholderRange(offset=7, length=5),
PlaceholderRange(offset=12, length=6),
]
],
},
expected_modality_idxs=[
("audio", 0),
......@@ -363,7 +360,7 @@ async def test_fetch_video_http_with_dynamic_loader(
],
"video": [
PlaceholderRange(offset=8, length=5),
]
],
},
expected_modality_idxs=[
("image", 0),
......@@ -386,7 +383,7 @@ async def test_fetch_video_http_with_dynamic_loader(
],
"video": [
PlaceholderRange(offset=8, length=5),
]
],
},
expected_modality_idxs=[
("image", 0),
......@@ -398,7 +395,6 @@ async def test_fetch_video_http_with_dynamic_loader(
),
],
)
# yapf: enable
def test_argsort_mm_positions(case):
mm_positions = case["mm_positions"]
expected_modality_idxs = case["expected_modality_idxs"]
......@@ -413,13 +409,16 @@ def test_argsort_mm_positions(case):
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_allowed_media_domains(video_url: str, num_frames: int):
connector = MediaConnector(
media_io_kwargs={"video": {
"num_frames": num_frames,
}},
media_io_kwargs={
"video": {
"num_frames": num_frames,
}
},
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])
],
)
video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
......
......@@ -59,48 +59,52 @@ def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
)
# yapf: disable
@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [
(None, [{}, {}]),
({}, [{}, {}]),
({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
])
# yapf: enable
@pytest.mark.parametrize(
"mm_processor_kwargs,expected_mm_kwargs",
[
(None, [{}, {}]),
({}, [{}, {}]),
({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
],
)
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
"""Test mm_processor_kwargs init for zipping enc/dec prompts."""
encoder_prompts = ['An encoder prompt', 'Another encoder prompt']
decoder_prompts = ['A decoder prompt', 'Another decoder prompt']
zipped_prompts = zip_enc_dec_prompts(encoder_prompts, decoder_prompts,
mm_processor_kwargs)
encoder_prompts = ["An encoder prompt", "Another encoder prompt"]
decoder_prompts = ["A decoder prompt", "Another decoder prompt"]
zipped_prompts = zip_enc_dec_prompts(
encoder_prompts, decoder_prompts, mm_processor_kwargs
)
assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
for enc, dec, exp_kwargs, zipped in zip(encoder_prompts, decoder_prompts,
expected_mm_kwargs,
zipped_prompts):
for enc, dec, exp_kwargs, zipped in zip(
encoder_prompts, decoder_prompts, expected_mm_kwargs, zipped_prompts
):
assert isinstance(zipped, dict)
assert len(zipped.keys()) == 3
assert zipped['encoder_prompt'] == enc
assert zipped['decoder_prompt'] == dec
assert zipped['mm_processor_kwargs'] == exp_kwargs
@pytest.mark.parametrize("model_id", [
"facebook/opt-125m",
])
@pytest.mark.parametrize("prompt", [
{
"prompt": "",
"multi_modal_data": {
"dummy": []
assert zipped["encoder_prompt"] == enc
assert zipped["decoder_prompt"] == dec
assert zipped["mm_processor_kwargs"] == exp_kwargs
@pytest.mark.parametrize(
"model_id",
[
"facebook/opt-125m",
],
)
@pytest.mark.parametrize(
"prompt",
[
{
"prompt": "",
"multi_modal_data": {"dummy": []},
},
},
{
"prompt_token_ids": [],
"multi_modal_data": {
"dummy": []
{
"prompt_token_ids": [],
"multi_modal_data": {"dummy": []},
},
},
])
],
)
def test_preprocessor_text_no_mm_inputs(model_id, prompt):
model_config = ModelConfig(model=model_id)
tokenizer = init_tokenizer_from_configs(model_config)
......@@ -110,15 +114,19 @@ def test_preprocessor_text_no_mm_inputs(model_id, prompt):
input_preprocessor.preprocess(prompt)
@pytest.mark.parametrize("model_id", [
"facebook/chameleon-7b",
])
@pytest.mark.parametrize("prompt", [
"",
{
"prompt_token_ids": []
},
])
@pytest.mark.parametrize(
"model_id",
[
"facebook/chameleon-7b",
],
)
@pytest.mark.parametrize(
"prompt",
[
"",
{"prompt_token_ids": []},
],
)
def test_preprocessor_always_mm_code_path(model_id, prompt):
model_config = ModelConfig(model=model_id)
tokenizer = init_tokenizer_from_configs(model_config)
......
......@@ -9,14 +9,10 @@ import pytest
import torch
import torch_xla
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe as pallas_moe
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as torch_moe,
)
# yapf: enable
from vllm.platforms import current_platform
if not current_platform.is_tpu():
......
......@@ -388,7 +388,6 @@ def test_duplicate_dict_args(caplog_vllm, parser):
assert "-O.level" in caplog_vllm.text
# yapf: enable
@pytest.mark.parametrize(
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
[
......@@ -408,7 +407,6 @@ def test_duplicate_dict_args(caplog_vllm, parser):
(lambda foo, **kwargs: None, "foo", True, True, False),
],
)
# yapf: disable
def test_supports_kw(
callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
):
......@@ -681,7 +679,6 @@ def test_lru_cache():
assert 6 in cache
# yapf: disable
@pytest.mark.parametrize(
("src_dtype", "tgt_dtype", "expected_result"),
[
......@@ -715,12 +712,10 @@ def test_lru_cache():
(torch.complex64, torch.complex32, False),
],
)
# yapf: enable
def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result
# yapf: disable
@pytest.mark.parametrize(
("dtypes", "expected_result"),
[
......@@ -730,7 +725,6 @@ def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501
],
)
# yapf: enable
def test_common_broadcastable_dtype(dtypes, expected_result):
assert common_broadcastable_dtype(dtypes) == expected_result
......@@ -775,7 +769,6 @@ def test_placeholder_module_error_handling():
_ = placeholder_attr.module
# yapf: disable
@pytest.mark.parametrize(
"obj,key1,key2",
[
......@@ -785,8 +778,8 @@ def test_placeholder_module_error_handling():
({1: "a", 2: "b"}, 1, 3),
# Tests for both keys do not exist
({1: "a", 2: "b"}, 3, 4),
])
# yapf: enable
],
)
def test_swap_dict_values(obj, key1, key2):
original_obj = obj.copy()
swap_dict_values(obj, key1, key2)
......@@ -800,26 +793,30 @@ def test_swap_dict_values(obj, key1, key2):
assert key1 not in obj
def test_model_specification(parser_with_config, cli_config_file,
cli_config_file_with_model):
def test_model_specification(
parser_with_config, cli_config_file, cli_config_file_with_model
):
# Test model in CLI takes precedence over config
args = parser_with_config.parse_args(
['serve', 'cli-model', '--config', cli_config_file_with_model])
assert args.model_tag == 'cli-model'
assert args.served_model_name == 'mymodel'
["serve", "cli-model", "--config", cli_config_file_with_model]
)
assert args.model_tag == "cli-model"
assert args.served_model_name == "mymodel"
# Test model from config file works
args = parser_with_config.parse_args([
'serve',
'--config',
cli_config_file_with_model,
])
assert args.model == 'config-model'
assert args.served_model_name == 'mymodel'
args = parser_with_config.parse_args(
[
"serve",
"--config",
cli_config_file_with_model,
]
)
assert args.model == "config-model"
assert args.served_model_name == "mymodel"
# Test no model specified anywhere raises error
with pytest.raises(ValueError, match="No model specified!"):
parser_with_config.parse_args(['serve', '--config', cli_config_file])
parser_with_config.parse_args(["serve", "--config", cli_config_file])
# Test using --model option raises error
# with pytest.raises(
......@@ -833,47 +830,52 @@ def test_model_specification(parser_with_config, cli_config_file,
# Test using --model option back-compatibility
# (when back-compatibility ends, the above test should be uncommented
# and the below test should be removed)
args = parser_with_config.parse_args([
'serve',
'--tensor-parallel-size',
'2',
'--model',
'my-model',
'--trust-remote-code',
'--port',
'8001',
])
args = parser_with_config.parse_args(
[
"serve",
"--tensor-parallel-size",
"2",
"--model",
"my-model",
"--trust-remote-code",
"--port",
"8001",
]
)
assert args.model is None
assert args.tensor_parallel_size == 2
assert args.trust_remote_code is True
assert args.port == 8001
args = parser_with_config.parse_args([
'serve',
'--tensor-parallel-size=2',
'--model=my-model',
'--trust-remote-code',
'--port=8001',
])
args = parser_with_config.parse_args(
[
"serve",
"--tensor-parallel-size=2",
"--model=my-model",
"--trust-remote-code",
"--port=8001",
]
)
assert args.model is None
assert args.tensor_parallel_size == 2
assert args.trust_remote_code is True
assert args.port == 8001
# Test other config values are preserved
args = parser_with_config.parse_args([
'serve',
'cli-model',
'--config',
cli_config_file_with_model,
])
args = parser_with_config.parse_args(
[
"serve",
"cli-model",
"--config",
cli_config_file_with_model,
]
)
assert args.tensor_parallel_size == 2
assert args.trust_remote_code is True
assert args.port == 12312
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
(None, bool, [1, 2, 3])])
@pytest.mark.parametrize("input", [(), ("abc",), (None,), (None, bool, [1, 2, 3])])
def test_sha256(input: tuple):
digest = sha256(input)
assert digest is not None
......@@ -887,7 +889,7 @@ def test_sha256(input: tuple):
assert digest == sha256(input)
# hashing different input, returns different value
assert digest != sha256(input + (1, ))
assert digest != sha256(input + (1,))
@pytest.mark.parametrize(
......@@ -897,7 +899,8 @@ def test_sha256(input: tuple):
("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address
("inproc://some_identifier", ("inproc", "some_identifier", "")),
])
],
)
def test_split_zmq_path(path, expected):
assert split_zmq_path(path) == expected
......@@ -909,7 +912,8 @@ def test_split_zmq_path(path, expected):
"tcp://127.0.0.1", # Missing port
"tcp://[::1]", # Missing port for IPv6
"tcp://:5555", # Missing host
])
],
)
def test_split_zmq_path_invalid(invalid_path):
with pytest.raises(ValueError):
split_zmq_path(invalid_path)
......@@ -931,8 +935,9 @@ def test_make_zmq_socket_ipv6():
zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)
# Verify that the IPV6 option is set
assert zsock.getsockopt(
zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"
assert zsock.getsockopt(zmq.IPV6) == 1, (
"IPV6 option should be enabled for IPv6 addresses"
)
# Clean up
zsock.close()
......@@ -1019,15 +1024,14 @@ def test_convert_ids_list_to_tokens():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
token_ids = tokenizer.encode("Hello, world!")
# token_ids = [9707, 11, 1879, 0]
assert tokenizer.convert_ids_to_tokens(token_ids) == [
'Hello', ',', 'Ġworld', '!'
]
assert tokenizer.convert_ids_to_tokens(token_ids) == ["Hello", ",", "Ġworld", "!"]
tokens = convert_ids_list_to_tokens(tokenizer, token_ids)
assert tokens == ['Hello', ',', ' world', '!']
assert tokens == ["Hello", ",", " world", "!"]
def test_current_stream_multithread():
import threading
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
......@@ -1046,13 +1050,18 @@ def test_current_stream_multithread():
child_thread.start()
try:
assert thread_stream_ready.wait(
timeout=5), "Child thread failed to enter stream context in time"
assert thread_stream_ready.wait(timeout=5), (
"Child thread failed to enter stream context in time"
)
main_current_stream = current_stream()
assert main_current_stream != child_stream, "Main thread's current_stream was contaminated by child thread"
assert main_current_stream == main_default_stream, "Main thread's current_stream is not the default stream"
assert main_current_stream != child_stream, (
"Main thread's current_stream was contaminated by child thread"
)
assert main_current_stream == main_default_stream, (
"Main thread's current_stream is not the default stream"
)
# Notify child thread it can exit
thread_can_exit.set()
......@@ -1070,7 +1079,7 @@ def test_load_config_file(tmp_path):
"enable-logging": True,
"list-arg": ["item1", "item2"],
"port": 12323,
"tensor-parallel-size": 4
"tensor-parallel-size": 4,
}
# Write the configuration data to a temporary YAML file
......
......@@ -16,9 +16,6 @@ from vllm.multimodal.inputs import (
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256, sha256_cbor
from vllm.v1.core.kv_cache_manager import KVCacheManager
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
from vllm.v1.core.kv_cache_utils import (
BlockHash,
FreeKVCacheBlockQueue,
......@@ -48,8 +45,6 @@ from vllm.v1.kv_cache_interface import (
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
# yapf: enable
pytestmark = pytest.mark.cpu_test
......
......@@ -22,8 +22,6 @@ from vllm.config import VllmConfig
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available
# yapf: disable
from vllm.v1.sample.logits_processor import (
BatchUpdate,
BatchUpdateBuilder,
......@@ -34,8 +32,6 @@ from vllm.v1.sample.logits_processor import (
MoveDirectionality,
build_logitsprocs,
)
# yapf: enable
from vllm.v1.sample.metadata import SamplingMetadata
PIN_MEMORY_AVAILABLE = is_pin_memory_available()
......
......@@ -7,8 +7,6 @@ from typing import Union
import pytest
from tests.utils import create_new_process_for_each_test
# yapf: disable
from tests.v1.logits_processors.utils import (
DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN,
......@@ -24,8 +22,6 @@ from tests.v1.logits_processors.utils import (
prompts,
)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
# yapf: enable
from vllm import LLM, SamplingParams
from vllm.v1.sample.logits_processor import (
STR_POOLING_REJECTS_LOGITSPROCS,
......
......@@ -11,8 +11,6 @@ import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_test
# yapf: disable
from tests.v1.logits_processors.utils import (
DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN,
......@@ -25,8 +23,6 @@ from tests.v1.logits_processors.utils import (
)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
# yapf: enable
def _server_with_logitproc_entrypoint(
env_dict: Optional[dict[str, str]],
......
......@@ -4,7 +4,6 @@
import importlib
from typing import TYPE_CHECKING, Callable
# yapf: disable
import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase,
......@@ -13,8 +12,6 @@ from vllm.distributed.kv_transfer.kv_connector.base import (
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.logger import init_logger
# yapf: enable
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# yapf: disable
import argparse
import copy
import dataclasses
......@@ -88,8 +87,6 @@ from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor
from vllm.v1.sample.logits_processor import LogitsProcessor
# yapf: enable
if TYPE_CHECKING:
from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.layers.quantization import QuantizationMethods
......
......@@ -17,9 +17,6 @@ import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImageParam,
......@@ -40,8 +37,6 @@ from openai.types.responses import ResponseInputImageParam
from openai_harmony import Message as OpenAIHarmonyMessage
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
# yapf: enable
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin
# pydantic needs the TypedDict from typing_extensions
......@@ -52,11 +47,7 @@ from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MediaConnector
# yapf: disable
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
# yapf: enable
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid, supports_kw
......@@ -317,11 +308,7 @@ def _is_var_or_elems_access(
):
return _is_var_or_elems_access(node.node, varname, key)
# yapf: disable
return (
_is_attr_access(node, varname, key) if key
else _is_var_access(node, varname)
) # yapf: enable
return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
......
......@@ -39,9 +39,6 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
resolve_chat_template_content_format,
)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.score_utils import (
ScoreContentPartParam,
ScoreMultiModalParam,
......@@ -50,8 +47,6 @@ from vllm.entrypoints.score_utils import (
compress_token_type_ids,
get_score_prompt,
)
# yapf: enable
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args
from vllm.inputs import (
DataPrompt,
......
......@@ -49,9 +49,6 @@ from vllm.entrypoints.chat_utils import (
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
......@@ -84,8 +81,6 @@ from vllm.entrypoints.openai.protocol import (
TranslationResponse,
UnloadLoRAAdapterRequest,
)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_classification import ServingClassification
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
......
......@@ -11,8 +11,6 @@ from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar
import regex as re
import torch
from fastapi import HTTPException, UploadFile
# yapf: disable
from openai.types.chat.chat_completion_audio import (
ChatCompletionAudio as OpenAIChatCompletionAudio,
)
......@@ -46,8 +44,6 @@ from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreated
from openai.types.responses import (
ResponseInProgressEvent as OpenAIResponseInProgressEvent,
)
# yapf: enable
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
)
......
......@@ -18,8 +18,6 @@ from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf: disable
from vllm.entrypoints.openai.protocol import (
BatchRequestInput,
BatchRequestOutput,
......@@ -30,8 +28,6 @@ from vllm.entrypoints.openai.protocol import (
RerankResponse,
ScoreResponse,
)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
......
......@@ -1733,13 +1733,15 @@ class OpenAIServingChat(OpenAIServing):
is a tool call with arguments.
"""
# yapf: disable
return bool(
# if there is a delta message that includes tool calls which
# include a function that has arguments
output.finish_reason is not None
and self.enable_auto_tools and self.tool_parser and delta_message
and delta_message.tool_calls and delta_message.tool_calls[0]
and self.enable_auto_tools
and self.tool_parser
and delta_message
and delta_message.tool_calls
and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None
)
......
......@@ -18,8 +18,6 @@ from vllm.entrypoints.openai.protocol import (
ErrorResponse,
UsageInfo,
)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (
ClassificationServeContext,
OpenAIServing,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment