# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from collections.abc import Mapping, Set
from dataclasses import dataclass, field
from typing import Any, Literal

import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION

from vllm.config.model import ModelDType, TokenizerMode

# Root directory under which the example checkpoints are mirrored locally.
# `VLLM_OPTEST_MODELS_PATH` takes precedence over `OPTEST_MODELS_PATH`.
# When neither is set we fall back to "" so that
# `os.path.join("", "<org>/<model>")` yields the plain Hugging Face Hub id,
# instead of `os.path.join(None, ...)` raising `TypeError` while this module
# is being imported.
models_path_prefix = (
    os.getenv("VLLM_OPTEST_MODELS_PATH") or os.getenv("OPTEST_MODELS_PATH") or ""
)
@dataclass(frozen=True)
class _HfExamplesInfo:
    default: str
    """The default model to use for testing this architecture."""

    extras: Mapping[str, str] = field(default_factory=dict)
    """Extra models to use for testing this architecture."""

    tokenizer: str | None = None
    """Set the tokenizer to load for this architecture."""

    tokenizer_mode: TokenizerMode | str = "auto"
    """Set the tokenizer type for this architecture."""

    speculative_model: str | None = None
    """
    The default model to use for testing this architecture, which is only used
    for speculative decoding.
    """

    speculative_method: str | None = None
    """
    The method to use for speculative decoding.
    """

    min_transformers_version: str | None = None
    """
    The minimum version of HF Transformers that is required to run this model.
    """

    max_transformers_version: str | None = None
    """
    The maximum version of HF Transformers that this model runs on.
    """

    transformers_version_reason: dict[Literal["vllm", "hf"], str] | None = None
    """
    The type and reason to skip test for the minimum/maximum version requirement.
    vllm: skip all vLLM tests if the version requirement is not met.
    hf: only skip tests that use the HF runner if the version requirement is
    not met.
    If unset, every test is skipped when the version requirement is not met.
    """

    require_embed_inputs: bool = False
    """
    If `True`, enables prompt and multi-modal embedding inputs while
    disabling tokenization.
    """

    dtype: ModelDType = "auto"
    """
    The data type for the model weights and activations.
    """

    enforce_eager: bool = False
    """
    Whether to enforce eager execution. If True, we will
    disable CUDA graph and always execute the model in eager mode.
    If False, we will use CUDA graph and eager execution in hybrid.
    """

    is_available_online: bool = True
    """
    Set this to `False` if the name of this architecture no longer exists on
    the HF repo. To maintain backwards compatibility, we have not removed them
    from the main model registry, so without this flag the registry tests will
    fail.
    """

    trust_remote_code: bool = False
    """The `trust_remote_code` level required to load the model."""

    hf_overrides: dict[str, Any] = field(default_factory=dict)
    """The `hf_overrides` required to load the model."""

    max_model_len: int | None = None
    """
    The maximum model length to use for this model. Some models default to a
    length that is too large to fit into memory in CI.
    """

    max_num_batched_tokens: int | None = None
    """
    The maximum number of tokens to be processed in a single batch.
    """

    revision: str | None = None
    """
    The specific revision (commit hash, tag, or branch) to use for the model.
    If not specified, the default revision will be used.
    """

    max_num_seqs: int | None = None
    """Maximum number of sequences to be processed in a single iteration."""

    use_original_num_layers: bool = False
    """
    If True, use the original number of layers from the model config
    instead of minimal layers for testing.
    """

    def check_transformers_version(
        self,
        *,
        on_fail: Literal["error", "skip", "return"],
        check_version_reason: Literal["vllm", "hf"] = "hf",
        check_min_version: bool = True,
        check_max_version: bool = True,
    ) -> str | None:
        """
        If the installed transformers version does not meet the requirements,
        perform the given action.

        Args:
            on_fail: What to do when the requirement is not met: raise a
                `RuntimeError` ("error"), skip the current pytest test
                ("skip"), or only return the message ("return").
            check_version_reason: The runner the calling test depends on.
                When `transformers_version_reason` is set, an unmet
                requirement only applies if that mapping lists this runner
                or lists "vllm" (which affects every test).
            check_min_version: Whether to enforce `min_transformers_version`.
            check_max_version: Whether to enforce `max_transformers_version`.

        Returns:
            `None` if the requirement is met or does not apply to this
            caller; otherwise the failure message (reached only when
            `on_fail="return"`).
        """
        min_version = self.min_transformers_version
        max_version = self.max_transformers_version
        if min_version is None and max_version is None:
            return None

        current_version = TRANSFORMERS_VERSION
        # Only compare the base version, otherwise preview models cannot be
        # run because `x.yy.0.dev0` < `x.yy.0`.
        cur_base_version = Version(Version(current_version).base_version)
        msg = f"`transformers=={current_version}` installed, but `transformers"

        if (
            check_min_version
            and min_version
            and cur_base_version < Version(min_version)
        ):
            msg += f">={min_version}` is required to run this model."
        elif (
            check_max_version
            and max_version
            and cur_base_version > Version(max_version)
        ):
            msg += f"<={max_version}` is required to run this model."
        else:
            # The installed version satisfies every requirement being checked.
            return None

        reasons = self.transformers_version_reason
        if reasons:
            # The unmet requirement is only known to break the runners listed
            # in `reasons`: "vllm" breaks every test, otherwise only callers
            # depending on a listed runner need to be skipped.
            if "vllm" not in reasons and check_version_reason not in reasons:
                return None
            for reason_type, reason in reasons.items():
                msg += f" Reason({reason_type}): {reason}"

        if on_fail == "error":
            raise RuntimeError(msg)
        elif on_fail == "skip":
            pytest.skip(msg)

        return msg

    def check_available_online(
        self,
        *,
        on_fail: Literal["error", "skip"],
    ) -> None:
        """
        If the model is not available online, perform the given action.

        Args:
            on_fail: Raise a `RuntimeError` ("error"); any other value skips
                the current pytest test.
        """
        if self.is_available_online:
            return

        msg = "Model is not available online"
        if on_fail == "error":
            raise RuntimeError(msg)
        pytest.skip(msg)

# Example checkpoints for text-generation architectures. Every checkpoint is
# resolved under `models_path_prefix` so tests can run against a local mirror.
_TEXT_GENERATION_EXAMPLE_MODELS = {
    # [Decoder-only]
    "AfmoeForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "arcee-ai/Trinity-Nano-Preview")
    ),
    "ApertusForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "swiss-ai/Apertus-8B-Instruct-2509")
    ),
    "AquilaModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "BAAI/AquilaChat-7B"), trust_remote_code=True
    ),
    "AquilaForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "BAAI/AquilaChat2-7B"), trust_remote_code=True
    ),
    "ArceeForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "arcee-ai/AFM-4.5B-Base")
    ),
    "ArcticForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Snowflake/snowflake-arctic-instruct"),
        trust_remote_code=True,
    ),
    "BaiChuanForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "baichuan-inc/Baichuan-7B"),
        trust_remote_code=True,
    ),
    "BaichuanForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "baichuan-inc/Baichuan2-7B-chat"),
        trust_remote_code=True,
    ),
    "BailingMoeForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "inclusionAI/Ling-lite-1.5"),
        trust_remote_code=True,
    ),
    "BailingMoeV2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "inclusionAI/Ling-mini-2.0"),
        trust_remote_code=True,
    ),
    "BambaForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "ibm-ai-platform/Bamba-9B-v1"),
        extras={
            "tiny": os.path.join(
                models_path_prefix, "hmellor/tiny-random-BambaForCausalLM"
            )
        },
    ),
    "BloomForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "bigscience/bloom-560m"),
        {"1b": os.path.join(models_path_prefix, "bigscience/bloomz-1b1")},
    ),
    "ChatGLMModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "zai-org/chatglm3-6b"),
        trust_remote_code=True,
        max_transformers_version="4.48",
    ),
    "ChatGLMForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "thu-coai/ShieldLM-6B-chatglm3"),
        trust_remote_code=True,
    ),
    "CohereForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "CohereLabs/c4ai-command-r-v01"),
        trust_remote_code=True,
    ),
    "Cohere2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "CohereLabs/c4ai-command-r7b-12-2024"),
        trust_remote_code=True,
    ),
    "CwmForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "facebook/cwm"),
        min_transformers_version="4.58",
    ),
    # FIXME: databricks/dbrx-instruct has been deleted
    "DbrxForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "databricks/dbrx-instruct"),
        is_available_online=False,
    ),
    "DeciLMForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "nvidia/Llama-3_3-Nemotron-Super-49B-v1"),
        trust_remote_code=True,
    ),
    "DeepseekForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "deepseek-ai/deepseek-moe-16b-base"),
        trust_remote_code=True,
    ),
    "DeepseekV2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "deepseek-ai/DeepSeek-V2-Lite-Chat"),
        trust_remote_code=True,
    ),
    "DeepseekV3ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "deepseek-ai/DeepSeek-V3"),
        trust_remote_code=True,
    ),
    "DeepseekV32ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "deepseek-ai/DeepSeek-V3.2-Exp")
    ),
    "Ernie4_5ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "baidu/ERNIE-4.5-0.3B-PT")
    ),
    "Ernie4_5_MoeForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "baidu/ERNIE-4.5-21B-A3B-PT")
    ),
    "ExaoneForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"),
        trust_remote_code=True,
    ),
    "Exaone4ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "LGAI-EXAONE/EXAONE-4.0-32B")
    ),
    "ExaoneMoEForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "LGAI-EXAONE/K-EXAONE-236B-A23B"),
        min_transformers_version="5.0.0",
    ),
    "Fairseq2LlamaForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "mgleize/fairseq2-dummy-Llama-3.2-1B")
    ),
    "FalconForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "tiiuae/falcon-7b")
    ),
    "FalconH1ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "tiiuae/Falcon-H1-0.5B-Base")
    ),
    "FlexOlmoForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "allenai/Flex-reddit-2x7B-1T")
    ),
    "GemmaForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "google/gemma-1.1-2b-it")
    ),
    "Gemma2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "google/gemma-2-9b"),
        extras={"tiny": os.path.join(models_path_prefix, "google/gemma-2-2b-it")},
    ),
    "Gemma3ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "google/gemma-3-1b-it")
    ),
    "Gemma3nForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "google/gemma-3n-E2B-it")
    ),
    "GlmForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "zai-org/glm-4-9b-chat-hf")
    ),
    "Glm4ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "zai-org/GLM-4-9B-0414")
    ),
    "Glm4MoeForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "zai-org/GLM-4.5")
    ),
    "Glm4MoeLiteForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "zai-org/GLM-4.7-Flash"),
        min_transformers_version="5.0.0.dev",
        is_available_online=False,
    ),
    "GPT2LMHeadModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "openai-community/gpt2"),
        {"alias": os.path.join(models_path_prefix, "gpt2")},
    ),
    "GPTBigCodeForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "bigcode/starcoder"),
        extras={
            "tiny": os.path.join(models_path_prefix, "bigcode/tiny_starcoder_py"),
            "santacoder": os.path.join(
                models_path_prefix, "bigcode/gpt_bigcode-santacoder"
            ),
        },
    ),
    "GPTJForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Milos/slovak-gpt-j-405M"),
        {"6b": os.path.join(models_path_prefix, "EleutherAI/gpt-j-6b")},
    ),
    "GPTNeoXForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "EleutherAI/pythia-70m"),
        {"1b": os.path.join(models_path_prefix, "EleutherAI/pythia-1.4b")},
    ),
    "GptOssForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "lmsys/gpt-oss-20b-bf16")
    ),
    "GraniteForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "ibm/PowerLM-3b")
    ),
    "GraniteMoeForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "ibm/PowerMoE-3b")
    ),
    "GraniteMoeHybridForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "ibm-granite/granite-4.0-tiny-preview")
    ),
    "GraniteMoeSharedForCausalLM": _HfExamplesInfo(
        os.path.join(
            models_path_prefix, "ibm-research/moe-7b-1b-active-shared-experts"
        )
    ),
    "Grok1ModelForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "hpcai-tech/grok-1"), trust_remote_code=True
    ),
    "Grok1ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "xai-org/grok-2"), trust_remote_code=True
    ),
    "HunYuanDenseV1ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "tencent/Hunyuan-7B-Instruct")
    ),
    "HunYuanMoEV1ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "tencent/Hunyuan-A13B-Instruct"),
        trust_remote_code=True,
    ),
    "InternLMForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "internlm/internlm-chat-7b"),
        trust_remote_code=True,
    ),
    "InternLM2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "internlm/internlm2-chat-7b"),
        trust_remote_code=True,
    ),
    "InternLM2VEForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "OpenGVLab/Mono-InternVL-2B"),
        trust_remote_code=True,
    ),
    "InternLM3ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "internlm/internlm3-8b-instruct"),
        trust_remote_code=True,
    ),
    "JAISLMHeadModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "inceptionai/jais-13b-chat")
    ),
    "Jais2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "inceptionai/Jais-2-8B-Chat"),
        min_transformers_version="4.58",
    ),
    "IQuestCoderForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "IQuestLab/IQuest-Coder-V1-40B-Instruct"),
        trust_remote_code=True,
    ),
    "IQuestLoopCoderForCausalLM": _HfExamplesInfo(
        os.path.join(
            models_path_prefix, "IQuestLab/IQuest-Coder-V1-40B-Loop-Instruct"
        ),
        trust_remote_code=True,
    ),
    "JambaForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "ai21labs/AI21-Jamba-1.5-Mini"),
        extras={
            "tiny": os.path.join(models_path_prefix, "ai21labs/Jamba-tiny-dev"),
            "random": os.path.join(models_path_prefix, "ai21labs/Jamba-tiny-random"),
        },
    ),
    "KimiLinearForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "moonshotai/Kimi-Linear-48B-A3B-Instruct"),
        trust_remote_code=True,
    ),
    "Lfm2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "LiquidAI/LFM2-1.2B")
    ),
    "Lfm2MoeForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "LiquidAI/LFM2-8B-A1B"),
        min_transformers_version="4.58",
    ),
    "LlamaForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"),
        extras={
            "guard": os.path.join(models_path_prefix, "meta-llama/Llama-Guard-3-1B"),
            "hermes": os.path.join(
                models_path_prefix, "NousResearch/Hermes-3-Llama-3.1-8B"
            ),
            "fp8": os.path.join(
                models_path_prefix, "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
            ),
            "tiny": os.path.join(
                models_path_prefix, "hmellor/tiny-random-LlamaForCausalLM"
            ),
        },
    ),
    "LLaMAForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "decapoda-research/llama-7b-hf"),
        is_available_online=False,
    ),
    "Llama4ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "meta-llama/Llama-4-Scout-17B-16E-Instruct"),
    ),
    "LongcatFlashForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "meituan-longcat/LongCat-Flash-Chat"),
        trust_remote_code=True,
    ),
    "MambaForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "state-spaces/mamba-130m-hf")
    ),
    "Mamba2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "mistralai/Mamba-Codestral-7B-v0.1"),
        extras={
            "random": os.path.join(
                models_path_prefix, "yujiepan/mamba2-codestral-v0.1-tiny-random"
            ),
        },
    ),
    "FalconMambaForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "tiiuae/falcon-mamba-7b-instruct")
    ),
    "MiniCPMForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "openbmb/MiniCPM-2B-sft-bf16"),
        trust_remote_code=True,
    ),
    "MiniCPM3ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "openbmb/MiniCPM3-4B"),
        trust_remote_code=True,
    ),
    "MiniCPM4ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "openbmb/MiniCPM4.1-8B"),
        trust_remote_code=True,
    ),
    "MiniMaxForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-Text-01-hf")
    ),
    "MiniMaxText01ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-Text-01"),
        trust_remote_code=True,
        revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3",
    ),
    "MiniMaxM1ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-M1-40k"),
        trust_remote_code=True,
    ),
    "MiniMaxM2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-M2"),
        trust_remote_code=True,
    ),
    "MistralForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "mistralai/Mistral-7B-Instruct-v0.1")
    ),
    "MistralLarge3ForCausalLM": _HfExamplesInfo(
        os.path.join(
            models_path_prefix, "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4"
        )
    ),
    "MixtralForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "mistralai/Mixtral-8x7B-Instruct-v0.1"),
        {"tiny": os.path.join(models_path_prefix, "TitanML/tiny-mixtral")},
    ),
    "MptForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "mpt"), is_available_online=False
    ),
    # FIXME: mosaicml/mpt-7b has been deleted
    "MPTForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "mosaicml/mpt-7b"), is_available_online=False
    ),
    "NemotronForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "nvidia/Minitron-8B-Base")
    ),
    "NemotronHForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "nvidia/Nemotron-H-8B-Base-8K"),
        trust_remote_code=True,
    ),
    "OlmoForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "allenai/OLMo-1B-hf")
    ),
    "Olmo2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "allenai/OLMo-2-0425-1B")
    ),
    "Olmo3ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "allenai/Olmo-3-7B-Instruct")
    ),
    "OlmoeForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "allenai/OLMoE-1B-7B-0924-Instruct")
    ),
    "OpenPanguMTPModel": _HfExamplesInfo(
        os.path.join(
            models_path_prefix, "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1"
        ),
        trust_remote_code=True,
        is_available_online=False,
    ),
    "OPTForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "facebook/opt-125m"),
        {"1b": os.path.join(models_path_prefix, "facebook/opt-iml-max-1.3b")},
    ),
    "OrionForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "OrionStarAI/Orion-14B-Chat"),
        trust_remote_code=True,
    ),
    "OuroForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "ByteDance/Ouro-1.4B"),
        trust_remote_code=True,
    ),
    "PanguEmbeddedForCausalLM": _HfExamplesInfo(
        os.path.join(
            models_path_prefix, "FreedomIntelligence/openPangu-Embedded-7B-V1.1"
        ),
        trust_remote_code=True,
    ),
    # No public checkpoint; intentionally left empty and marked offline.
    "PanguProMoEV2ForCausalLM": _HfExamplesInfo(
        "",
        trust_remote_code=True,
        is_available_online=False,
    ),
    "PanguUltraMoEForCausalLM": _HfExamplesInfo(
        os.path.join(
            models_path_prefix, "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1"
        ),
        trust_remote_code=True,
        is_available_online=False,
    ),
    "PersimmonForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "adept/persimmon-8b-chat")
    ),
    "PhiForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "microsoft/phi-2")
    ),
    "Phi3ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "microsoft/Phi-3-mini-4k-instruct")
    ),
    "PhiMoEForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "microsoft/Phi-3.5-MoE-instruct"),
        trust_remote_code=True,
    ),
    "Plamo2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "pfnet/plamo-2-1b"),
        trust_remote_code=True,
    ),
    "Plamo3ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "pfnet/plamo-3-nict-2b-base"),
        trust_remote_code=True,
    ),
    "QWenLMHeadModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen-7B-Chat"),
        max_transformers_version="4.53",
        transformers_version_reason={
            "hf": "HF model uses remote code that is not compatible with latest Transformers"  # noqa: E501
        },
        trust_remote_code=True,
    ),
    "Qwen2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen2-0.5B-Instruct"),
        extras={
            "2.5": os.path.join(models_path_prefix, "Qwen/Qwen2.5-0.5B-Instruct"),
            "2.5-1.5B": os.path.join(
                models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct"
            ),
        },
    ),
    "Qwen2MoeForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen1.5-MoE-A2.7B-Chat")
    ),
    "Qwen3ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen3-8B")
    ),
    "Qwen3MoeForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen3-30B-A3B")
    ),
    "Qwen3NextForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen3-Next-80B-A3B-Instruct"),
        extras={
            "tiny-random": os.path.join(
                models_path_prefix, "tiny-random/qwen3-next-moe"
            )
        },
        min_transformers_version="4.56.3",
    ),
    "RWForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "tiiuae/falcon-40b")
    ),
    "SeedOssForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "ByteDance-Seed/Seed-OSS-36B-Instruct"),
        trust_remote_code=True,
    ),
    "Step1ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "stepfun-ai/Step-Audio-EditX"),
        trust_remote_code=True,
    ),
    "SmolLM3ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "HuggingFaceTB/SmolLM3-3B")
    ),
    "StableLMEpochForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "stabilityai/stablelm-zephyr-3b")
    ),
    "StableLmForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "stabilityai/stablelm-3b-4e1t")
    ),
    "Starcoder2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "bigcode/starcoder2-3b")
    ),
    "Step3TextForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "stepfun-ai/step3"), trust_remote_code=True
    ),
    "SolarForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "upstage/solar-pro-preview-instruct"),
        trust_remote_code=True,
    ),
    "TeleChatForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "chuhac/TeleChat2-35B"),
        trust_remote_code=True,
    ),
    "TeleChat2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Tele-AI/TeleChat2-3B"),
        trust_remote_code=True,
    ),
    "TeleFLMForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "CofeAI/FLM-2-52B-Instruct-2407"),
        trust_remote_code=True,
    ),
    "XverseForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "xverse/XVERSE-7B-Chat"),
        tokenizer=os.path.join(models_path_prefix, "meta-llama/Llama-2-7b"),
        trust_remote_code=True,
    ),
    "Zamba2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Zyphra/Zamba2-7B-instruct")
    ),
    "MiMoForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "XiaomiMiMo/MiMo-7B-RL"),
        trust_remote_code=True,
    ),
    "MiMoV2FlashForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "XiaomiMiMo/MiMo-V2-Flash"),
        trust_remote_code=True,
    ),
    "Dots1ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "rednote-hilab/dots.llm1.inst")
    ),
}

# Example checkpoints for embedding / pooling architectures. Every checkpoint
# except the offline "llama" placeholder is resolved under `models_path_prefix`.
_EMBEDDING_EXAMPLE_MODELS = {
    # [Text-only]
    "BertModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "BAAI/bge-base-en-v1.5")
    ),
    "BgeM3EmbeddingModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "BAAI/bge-m3")
    ),
    "Gemma2Model": _HfExamplesInfo(
        os.path.join(models_path_prefix, "BAAI/bge-multilingual-gemma2")
    ),
    "Gemma3TextModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "google/embeddinggemma-300m")
    ),
    "GritLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "parasail-ai/GritLM-7B-vllm")
    ),
    "GteModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Snowflake/snowflake-arctic-embed-m-v2.0"),
        trust_remote_code=True,
    ),
    "GteNewModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Alibaba-NLP/gte-base-en-v1.5"),
        trust_remote_code=True,
        hf_overrides={"architectures": ["GteNewModel"]},
    ),
    "InternLM2ForRewardModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "internlm/internlm2-1_8b-reward"),
        trust_remote_code=True,
    ),
    "JambaForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "ai21labs/Jamba-tiny-reward-dev")
    ),
    "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
    "LlamaBidirectionalModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "nvidia/llama-nemotron-embed-1b-v2"),
        trust_remote_code=True,
    ),
    "MistralModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "intfloat/e5-mistral-7b-instruct")
    ),
    "ModernBertModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Alibaba-NLP/gte-modernbert-base"),
        trust_remote_code=True,
    ),
    "NomicBertModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "nomic-ai/nomic-embed-text-v2-moe"),
        trust_remote_code=True,
    ),
    "Qwen2Model": _HfExamplesInfo(
        os.path.join(models_path_prefix, "ssmits/Qwen2-7B-Instruct-embed-base")
    ),
    "Qwen2ForRewardModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-RM-72B"),
        max_transformers_version="4.53",
        transformers_version_reason={
            "hf": "HF model uses remote code that is not compatible with latest Transformers"  # noqa: E501
        },
    ),
    "Qwen2ForProcessRewardModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-PRM-7B"),
        max_transformers_version="4.53",
        transformers_version_reason={
            "hf": "HF model uses remote code that is not compatible with latest Transformers"  # noqa: E501
        },
    ),
    "RobertaModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "sentence-transformers/stsb-roberta-base-v2")
    ),
    "RobertaForMaskedLM": _HfExamplesInfo(
        os.path.join(
            models_path_prefix, "sentence-transformers/all-roberta-large-v1"
        )
    ),
    "XLMRobertaModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "intfloat/multilingual-e5-small")
    ),
    "BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "naver/splade-v3"),
        hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]},
    ),
    # [Multimodal]
    "CLIPModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "openai/clip-vit-base-patch32")
    ),
    "LlavaNextForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "royokong/e5-v")
    ),
    "Phi3VForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "TIGER-Lab/VLM2Vec-Full"),
        trust_remote_code=True,
    ),
    "Qwen2VLForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "MrLight/dse-qwen2-2b-mrl-v1")
    ),
    "SiglipModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "google/siglip-base-patch16-224")
    ),
    "PrithviGeoSpatialMAE": _HfExamplesInfo(
        os.path.join(
            models_path_prefix,
            "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
        ),
        dtype="float16",
        enforce_eager=True,
        require_embed_inputs=True,
        # This is to avoid the model going OOM in CI
        max_num_seqs=32,
    ),
    "Terratorch": _HfExamplesInfo(
        os.path.join(
            models_path_prefix,
            "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
        ),
        dtype="float16",
        enforce_eager=True,
        require_embed_inputs=True,
        # This is to avoid the model going OOM in CI
        max_num_seqs=32,
    ),
}

# Example checkpoints for architectures whose task is sequence or token
# classification (decoder-only classifiers and cross-encoders). Every checkpoint
# is resolved under ``models_path_prefix`` (read from VLLM_OPTEST_MODELS_PATH /
# OPTEST_MODELS_PATH at the top of this file).
_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
    # [Decoder-only]
    "GPT2ForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "nie3e/sentiment-polish-gpt2-small")
    ),
    # [Cross-encoder]
    "BertForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "cross-encoder/ms-marco-MiniLM-L-6-v2")
    ),
    "BertForTokenClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "boltuix/NeuroBERT-NER")
    ),
    "GteNewForSequenceClassification": _HfExamplesInfo(
        os.path.join(
            models_path_prefix, "Alibaba-NLP/gte-multilingual-reranker-base"
        ),
        trust_remote_code=True,
        hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
    ),
    "LlamaBidirectionalForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "nvidia/llama-nemotron-rerank-1b-v2"),
        trust_remote_code=True,
    ),
    "ModernBertForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Alibaba-NLP/gte-reranker-modernbert-base")
    ),
    "ModernBertForTokenClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "disham993/electrical-ner-ModernBERT-base")
    ),
    "RobertaForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "cross-encoder/quora-roberta-base")
    ),
    "XLMRobertaForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "BAAI/bge-reranker-v2-m3")
    ),
}

629
630
# Example checkpoints for classification architectures obtained through
# automatic conversion (``as_seq_cls_model``). This table is not merged into
# ``_EXAMPLE_MODELS``; it is exposed through its own registry instance.
_AUTOMATIC_CONVERTED_MODELS = {
    # Use as_seq_cls_model for automatic conversion
    "GemmaForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "BAAI/bge-reranker-v2-gemma"),
        hf_overrides={
            "architectures": ["GemmaForSequenceClassification"],
            "classifier_from_token": ["Yes"],
            "method": "no_post_processing",
        },
    ),
    "LlamaForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Skywork/Skywork-Reward-V2-Llama-3.2-1B")
    ),
    "Qwen2ForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "jason9693/Qwen2.5-1.5B-apeach")
    ),
    "Qwen3ForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "tomaarsen/Qwen3-Reranker-0.6B-seq-cls")
    ),
    "Qwen3ForTokenClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "bd2lcco/Qwen3-0.6B-finetuned")
    ),
    "Qwen3VLForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen3-VL-Reranker-2B"),
        is_available_online=False,
        hf_overrides={
            "architectures": ["Qwen3VLForSequenceClassification"],
            "classifier_from_token": ["no", "yes"],
            "is_original_qwen3_reranker": True,
        },
    ),
}

658
659
# Example checkpoints for multimodal architectures. Every checkpoint (default
# and extras) is resolved under ``models_path_prefix``. ``hf_overrides``
# values such as ``architectures`` are model-config fields, not paths, and
# must never be joined with the prefix.
_MULTIMODAL_EXAMPLE_MODELS = {
    # [Decoder-only]
    "AriaForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "rhymes-ai/Aria")
    ),
    "AudioFlamingo3ForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "nvidia/audio-flamingo-3-hf"),
        min_transformers_version="5.0.0.dev",
    ),
    "AyaVisionForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "CohereLabs/aya-vision-8b")
    ),
    "BagelForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "ByteDance-Seed/BAGEL-7B-MoT")
    ),
    "BeeForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Open-Bee/Bee-8B-RL"),
        trust_remote_code=True,
    ),
    "Blip2ForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Salesforce/blip2-opt-2.7b"),
        extras={"6b": os.path.join(models_path_prefix, "Salesforce/blip2-opt-6.7b")},
    ),
    "ChameleonForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "facebook/chameleon-7b")
    ),
    "Cohere2VisionForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "CohereLabs/command-a-vision-07-2025")
    ),
    "DeepseekVLV2ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "deepseek-ai/deepseek-vl2-tiny"),
        extras={
            "fork": os.path.join(models_path_prefix, "Isotr0py/deepseek-vl2-tiny")
        },
        max_transformers_version="4.48",
        transformers_version_reason={"hf": "HF model is not compatible."},
        hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
    ),
    "DeepseekOCRForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "deepseek-ai/DeepSeek-OCR"),
    ),
    "DotsOCRForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "rednote-hilab/dots.ocr"),
        trust_remote_code=True,
    ),
    "Eagle2_5_VLForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "nvidia/Eagle2.5-8B"),
        trust_remote_code=True,
        is_available_online=False,
    ),
    "Emu3ForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "BAAI/Emu3-Chat-hf")
    ),
    "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "baidu/ERNIE-4.5-VL-28B-A3B-PT"),
        trust_remote_code=True,
    ),
    "FuyuForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "adept/fuyu-8b")
    ),
    "Gemma3ForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "google/gemma-3-4b-it")
    ),
    "Gemma3nForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "google/gemma-3n-E2B-it")
    ),
    "GlmAsrForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "zai-org/GLM-ASR-Nano-2512"),
        trust_remote_code=True,
        min_transformers_version="5.0",
    ),
    "GraniteVision": _HfExamplesInfo(
        os.path.join(models_path_prefix, "ibm-granite/granite-vision-3.3-2b")
    ),
    "GraniteSpeechForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "ibm-granite/granite-speech-3.3-2b")
    ),
    "GLM4VForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "zai-org/glm-4v-9b"),
        trust_remote_code=True,
        hf_overrides={"architectures": ["GLM4VForCausalLM"]},
    ),
    "Glm4vForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "zai-org/GLM-4.1V-9B-Thinking")
    ),
    "Glm4vMoeForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "zai-org/GLM-4.5V")
    ),
    "H2OVLChatModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "h2oai/h2ovl-mississippi-800m"),
        trust_remote_code=True,
        extras={"2b": os.path.join(models_path_prefix, "h2oai/h2ovl-mississippi-2b")},
        max_transformers_version="4.48",
        transformers_version_reason={"hf": "HF model is not compatible."},
    ),
    "HCXVisionForCausalLM": _HfExamplesInfo(
        os.path.join(
            models_path_prefix,
            "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
        ),
        trust_remote_code=True,
    ),
    "HunYuanVLForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "tencent/HunyuanOCR"),
        hf_overrides={"num_experts": 0},
    ),
    "Idefics3ForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "HuggingFaceM4/Idefics3-8B-Llama3"),
        extras={
            "tiny": os.path.join(
                models_path_prefix, "HuggingFaceTB/SmolVLM-256M-Instruct"
            )
        },
    ),
    "IsaacForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "PerceptronAI/Isaac-0.1"),
        trust_remote_code=True,
        extras={
            "0.2-2B-Preview": os.path.join(
                models_path_prefix, "PerceptronAI/Isaac-0.2-2B-Preview"
            )
        },
    ),
    "InternS1ForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "internlm/Intern-S1"),
        trust_remote_code=True,
    ),
    "InternVLChatModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "OpenGVLab/InternVL2-1B"),
        extras={
            "2B": os.path.join(models_path_prefix, "OpenGVLab/InternVL2-2B"),
            "3.0": os.path.join(models_path_prefix, "OpenGVLab/InternVL3-1B"),
            "3.5-qwen3": os.path.join(models_path_prefix, "OpenGVLab/InternVL3_5-1B"),
            "3.5-qwen3moe": os.path.join(
                models_path_prefix, "OpenGVLab/InternVL3_5-30B-A3B"
            ),
            "3.5-gptoss": os.path.join(
                models_path_prefix, "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview"
            ),
        },
        trust_remote_code=True,
    ),
    "InternVLForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "OpenGVLab/InternVL3-1B-hf")
    ),
    "KananaVForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "kakaocorp/kanana-1.5-v-3b-instruct"),
        trust_remote_code=True,
    ),
    "KeyeForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Kwai-Keye/Keye-VL-8B-Preview"),
        trust_remote_code=True,
    ),
    "KeyeVL1_5ForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Kwai-Keye/Keye-VL-1_5-8B"),
        trust_remote_code=True,
    ),
    "KimiVLForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "moonshotai/Kimi-VL-A3B-Instruct"),
        extras={
            "thinking": os.path.join(
                models_path_prefix, "moonshotai/Kimi-VL-A3B-Thinking"
            )
        },
        trust_remote_code=True,
        max_transformers_version="4.53.3",
        transformers_version_reason={
            "hf": (
                "HF model uses deprecated transformers API "
                "(PytorchGELUTanh, DynamicCache.seen_tokens, and more). See: "
                "https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/discussions/31"
            )
        },
    ),
    "LightOnOCRForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "lightonai/LightOnOCR-1B-1025")
    ),
    "Lfm2VlForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "LiquidAI/LFM2-VL-450M"),
        min_transformers_version="5.0.0",
    ),
    "Llama4ForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "meta-llama/Llama-4-Scout-17B-16E-Instruct"),
        max_model_len=10240,
        extras={
            "llama-guard-4": os.path.join(
                models_path_prefix, "meta-llama/Llama-Guard-4-12B"
            )
        },
    ),
    "LlavaForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "llava-hf/llava-1.5-7b-hf"),
        extras={
            "mistral": os.path.join(models_path_prefix, "mistral-community/pixtral-12b"),
            "mistral-fp8": os.path.join(
                models_path_prefix, "nm-testing/pixtral-12b-FP8-dynamic"
            ),
        },
    ),
    "LlavaNextForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "llava-hf/llava-v1.6-mistral-7b-hf")
    ),
    "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "llava-hf/LLaVA-NeXT-Video-7B-hf")
    ),
    "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
    ),
    "MantisForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "TIGER-Lab/Mantis-8B-siglip-llama3"),
        max_transformers_version="4.48",
        transformers_version_reason={"hf": "HF model is not compatible."},
        hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
    ),
    "MiDashengLMModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "mispeech/midashenglm-7b"),
        trust_remote_code=True,
    ),
    "MiniCPMO": _HfExamplesInfo(
        os.path.join(models_path_prefix, "openbmb/MiniCPM-o-2_6"),
        trust_remote_code=True,
    ),
    "MiniCPMV": _HfExamplesInfo(
        os.path.join(models_path_prefix, "openbmb/MiniCPM-Llama3-V-2_5"),
        extras={
            "2.6": os.path.join(models_path_prefix, "openbmb/MiniCPM-V-2_6"),
            "4.0": os.path.join(models_path_prefix, "openbmb/MiniCPM-V-4"),
            "4.5": os.path.join(models_path_prefix, "openbmb/MiniCPM-V-4_5"),
        },
        trust_remote_code=True,
    ),
    "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-VL-01"),
        trust_remote_code=True,
    ),
    "Mistral3ForConditionalGeneration": _HfExamplesInfo(
        os.path.join(
            models_path_prefix, "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
        ),
        extras={
            "fp8": os.path.join(
                models_path_prefix,
                "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic",
            )
        },
    ),
    "MolmoForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "allenai/Molmo-7B-D-0924"),
        max_transformers_version="4.48",
        transformers_version_reason={
            "vllm": "Incorrectly-detected `tensorflow` import from processor."
        },
        extras={"olmo": os.path.join(models_path_prefix, "allenai/Molmo-7B-O-0924")},
        trust_remote_code=True,
    ),
    "Molmo2ForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "allenai/Molmo2-8B"),
        extras={"olmo": os.path.join(models_path_prefix, "allenai/Molmo2-O-7B")},
        min_transformers_version="4.51",
        trust_remote_code=True,
        # required by current PrefixLM implementation
        max_num_batched_tokens=31872,
    ),
    "NVLM_D": _HfExamplesInfo(
        os.path.join(models_path_prefix, "nvidia/NVLM-D-72B"),
        trust_remote_code=True,
    ),
    "Llama_Nemotron_Nano_VL": _HfExamplesInfo(
        os.path.join(models_path_prefix, "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"),
        trust_remote_code=True,
    ),
    "NemotronH_Nano_VL_V2": _HfExamplesInfo(
        os.path.join(models_path_prefix, "nano_vl_dummy"),
        is_available_online=False,
        trust_remote_code=True,
    ),
    "OpenCUAForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "xlangai/OpenCUA-7B"),
        trust_remote_code=True,
    ),
    "Ovis": _HfExamplesInfo(
        os.path.join(models_path_prefix, "AIDC-AI/Ovis2-1B"),
        trust_remote_code=True,
        max_transformers_version="4.53",
        transformers_version_reason={"hf": "HF model is not compatible"},
        extras={
            "1.6-llama": os.path.join(models_path_prefix, "AIDC-AI/Ovis1.6-Llama3.2-3B"),
            "1.6-gemma": os.path.join(models_path_prefix, "AIDC-AI/Ovis1.6-Gemma2-9B"),
        },
    ),
    "Ovis2_5": _HfExamplesInfo(
        os.path.join(models_path_prefix, "AIDC-AI/Ovis2.5-2B"),
        trust_remote_code=True,
    ),
    "PaddleOCRVLForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "PaddlePaddle/PaddleOCR-VL"),
        trust_remote_code=True,
    ),
    "PaliGemmaForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "google/paligemma-3b-mix-224"),
        extras={
            "v2": os.path.join(models_path_prefix, "google/paligemma2-3b-ft-docci-448")
        },
    ),
    "Phi3VForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "microsoft/Phi-3-vision-128k-instruct"),
        trust_remote_code=True,
        max_transformers_version="4.48",
        transformers_version_reason={
            "hf": "HF model use deprecated imports which have been removed."
        },  # noqa: E501
        extras={
            "phi3.5": os.path.join(
                models_path_prefix, "microsoft/Phi-3.5-vision-instruct"
            )
        },
    ),
    "Phi4MMForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "microsoft/Phi-4-multimodal-instruct"),
        trust_remote_code=True,
    ),
    "PixtralForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "mistralai/Pixtral-12B-2409"),
        extras={
            "mistral-large-3": os.path.join(
                models_path_prefix, "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4"
            ),
            "ministral-3": os.path.join(
                models_path_prefix, "mistralai/Ministral-3-3B-Instruct-2512"
            ),
        },
        tokenizer_mode="mistral",
    ),
    "QwenVLForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen-VL"),
        extras={"chat": os.path.join(models_path_prefix, "Qwen/Qwen-VL-Chat")},
        trust_remote_code=True,
        max_transformers_version="4.53.3",
        transformers_version_reason={
            "hf": "HF model uses deprecated imports which have been removed."
        },  # noqa: E501
        hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
    ),
    "Qwen2AudioForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen2-Audio-7B-Instruct")
    ),
    "Qwen2VLForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen2-VL-2B-Instruct")
    ),
    "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen2.5-VL-3B-Instruct"),
        max_model_len=4096,
    ),
    "Qwen2_5OmniModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen2.5-Omni-3B")
    ),
    "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen2.5-Omni-7B-AWQ")
    ),
    "Qwen3VLForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen3-VL-4B-Instruct"),
        max_model_len=4096,
        min_transformers_version="4.57",
    ),
    "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen3-VL-30B-A3B-Instruct"),
        max_model_len=4096,
        min_transformers_version="4.57",
    ),
    "Qwen3OmniMoeForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen3-Omni-30B-A3B-Instruct"),
        max_model_len=4096,
        min_transformers_version="4.57",
    ),
    "RForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "YannQi/R-4B"),
        trust_remote_code=True,
    ),
    "SkyworkR1VChatModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Skywork/Skywork-R1V-38B"),
        trust_remote_code=True,
    ),
    "SmolVLMForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "HuggingFaceTB/SmolVLM2-2.2B-Instruct")
    ),
    "Step3VLForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "stepfun-ai/step3"),
        trust_remote_code=True,
    ),
    "StepVLForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "stepfun-ai/Step3-VL-10B"),
        trust_remote_code=True,
    ),
    "UltravoxModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "fixie-ai/ultravox-v0_5-llama-3_2-1b"),
        trust_remote_code=True,
    ),
    "TarsierForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "omni-research/Tarsier-7b")
    ),
    "Tarsier2ForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "omni-research/Tarsier2-Recap-7b"),
        hf_overrides={
            # An architecture *name*, not a checkpoint path: it must not be
            # joined with models_path_prefix.
            "architectures": ["Tarsier2ForConditionalGeneration"],
            "model_type": "tarsier2",
        },
    ),
    "VoxtralForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "mistralai/Voxtral-Mini-3B-2507"),
        # disable this temporarily until we support HF format
        is_available_online=False,
    ),
    "VoxtralStreamingGeneration": _HfExamplesInfo(
        "<place-holder>",
        # disable this temporarily until we support HF format
        is_available_online=False,
    ),
    # [Encoder-decoder]
    "NemotronParseForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "nvidia/NVIDIA-Nemotron-Parse-v1.1"),
        trust_remote_code=True,
    ),
    "WhisperForConditionalGeneration": _HfExamplesInfo(
        os.path.join(models_path_prefix, "openai/whisper-large-v3-turbo"),
        extras={"v3": os.path.join(models_path_prefix, "openai/whisper-large-v3")},
    ),
    # [Cross-encoder]
    "JinaVLForRanking": _HfExamplesInfo(
        os.path.join(models_path_prefix, "jinaai/jina-reranker-m0")
    ),
}

986

987
# Example checkpoints for speculative-decoding architectures. ``speculative_model``
# names the checkpoint used only for speculative decoding and, like ``default``
# and ``tokenizer``, is resolved under ``models_path_prefix``.
# ``speculative_method`` is a method name (e.g. "eagle"), not a path.
_SPECULATIVE_DECODING_EXAMPLE_MODELS = {
    "MedusaModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "JackFram/llama-68m"),
        speculative_model=os.path.join(
            models_path_prefix, "abhigoyal/vllm-medusa-llama-68m-random"
        ),
    ),
    # Temporarily disabled.
    # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
    # "MLPSpeculatorPreTrainedModel": _HfExamplesInfo(
    #     "JackFram/llama-160m",
    #     speculative_model="ibm-ai-platform/llama-160m-accelerator"
    # ),
    "DeepSeekMTPModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "luccafong/deepseek_mtp_main_random"),
        speculative_model=os.path.join(
            models_path_prefix, "luccafong/deepseek_mtp_draft_random"
        ),
        trust_remote_code=True,
    ),
    "EagleDeepSeekMTPModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "eagle618/deepseek-v3-random"),
        speculative_model=os.path.join(
            models_path_prefix, "eagle618/eagle-deepseek-v3-random"
        ),
        trust_remote_code=True,
    ),
    "EagleLlamaForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "meta-llama/Meta-Llama-3-8B-Instruct"),
        trust_remote_code=True,
        speculative_model=os.path.join(
            models_path_prefix, "yuhuili/EAGLE-LLaMA3-Instruct-8B"
        ),
        tokenizer=os.path.join(
            models_path_prefix, "meta-llama/Meta-Llama-3-8B-Instruct"
        ),
    ),
    "Eagle3LlamaForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "meta-llama/Llama-3.1-8B-Instruct"),
        trust_remote_code=True,
        speculative_model=os.path.join(
            models_path_prefix, "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
        ),
        tokenizer=os.path.join(models_path_prefix, "meta-llama/Llama-3.1-8B-Instruct"),
        use_original_num_layers=True,
        max_model_len=10240,
    ),
    "EagleMistralLarge3ForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "mistralai/Mistral-Large-3-675B-Instruct-2512"),
        speculative_model=os.path.join(
            models_path_prefix, "mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle"
        ),
        # TODO: revert once figuring out OOM in CI
        is_available_online=False,
    ),
    "LlamaForCausalLMEagle3": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen3-8B"),
        trust_remote_code=True,
        speculative_model=os.path.join(models_path_prefix, "AngelSlim/Qwen3-8B_eagle3"),
        tokenizer=os.path.join(models_path_prefix, "Qwen/Qwen3-8B"),
        use_original_num_layers=True,
    ),
    "EagleLlama4ForCausalLM": _HfExamplesInfo(
        os.path.join(
            models_path_prefix, "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
        ),
        trust_remote_code=True,
        speculative_model=os.path.join(
            models_path_prefix, "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
        ),
        tokenizer=os.path.join(
            models_path_prefix, "meta-llama/Llama-4-Scout-17B-16E-Instruct"
        ),
    ),
    "EagleMiniCPMForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "openbmb/MiniCPM-1B-sft-bf16"),
        trust_remote_code=True,
        speculative_model=os.path.join(models_path_prefix, "openbmb/MiniCPM-2B-sft-bf16"),
        # A speculative-decoding method name, not a checkpoint path.
        speculative_method="eagle",
        tokenizer=os.path.join(models_path_prefix, "openbmb/MiniCPM-2B-sft-bf16"),
    ),
    "ErnieMTPModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "baidu/ERNIE-4.5-21B-A3B-PT"),
        trust_remote_code=True,
        speculative_model=os.path.join(models_path_prefix, "baidu/ERNIE-4.5-21B-A3B-PT"),
    ),
    "ExaoneMoeMTP": _HfExamplesInfo(
        os.path.join(models_path_prefix, "LGAI-EXAONE/K-EXAONE-236B-A23B"),
        speculative_model=os.path.join(
            models_path_prefix, "LGAI-EXAONE/K-EXAONE-236B-A23B"
        ),
        min_transformers_version="5.0.0",
    ),
    "Glm4MoeMTPModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "zai-org/GLM-4.5"),
        speculative_model=os.path.join(models_path_prefix, "zai-org/GLM-4.5"),
    ),
    "Glm4MoeLiteMTPModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "zai-org/GLM-4.7-Flash"),
        speculative_model=os.path.join(models_path_prefix, "zai-org/GLM-4.7-Flash"),
        is_available_online=False,
    ),
    "LongCatFlashMTPModel": _HfExamplesInfo(
        # Joined exactly once; a nested join would double a relative prefix.
        os.path.join(models_path_prefix, "meituan-longcat/LongCat-Flash-Chat"),
        trust_remote_code=True,
        speculative_model=os.path.join(
            models_path_prefix, "meituan-longcat/LongCat-Flash-Chat"
        ),
    ),
    "MiMoMTPModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "XiaomiMiMo/MiMo-7B-RL"),
        trust_remote_code=True,
        speculative_model=os.path.join(models_path_prefix, "XiaomiMiMo/MiMo-7B-RL"),
    ),
    "Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen2.5-VL-7B-Instruct"),
        speculative_model=os.path.join(
            models_path_prefix, "Rayzl/qwen2.5-vl-7b-eagle3-sgl"
        ),
    ),
    "Eagle3Qwen3vlForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen3-VL-8B-Instruct"),
        speculative_model=os.path.join(
            models_path_prefix, "taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3"
        ),
    ),
    "Qwen3NextMTP": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen3-Next-80B-A3B-Instruct"),
        min_transformers_version="4.56.3",
    ),
}

1089
# Example checkpoints for the generic Transformers-backend architectures.
# Most of them require a recent Transformers release (see
# ``min_transformers_version``).
_TRANSFORMERS_BACKEND_MODELS = {
    "TransformersEmbeddingModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "BAAI/bge-base-en-v1.5"),
        min_transformers_version="5.0.0.dev",
    ),
    "TransformersForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "papluca/xlm-roberta-base-language-detection"),
        min_transformers_version="5.0.0.dev",
    ),
    "TransformersForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "hmellor/Ilama-3.2-1B"),
        trust_remote_code=True,
    ),
    "TransformersMultiModalForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "BAAI/Emu3-Chat-hf")
    ),
    "TransformersMoEForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "allenai/OLMoE-1B-7B-0924"),
        min_transformers_version="5.0.0.dev",
    ),
    "TransformersMultiModalMoEForCausalLM": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen3-VL-30B-A3B-Instruct"),
        min_transformers_version="5.0.0.dev",
    ),
    "TransformersMoEEmbeddingModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen3-30B-A3B"),
        min_transformers_version="5.0.0.dev",
    ),
    "TransformersMoEForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "Qwen/Qwen3-30B-A3B"),
        min_transformers_version="5.0.0.dev",
    ),
    "TransformersMultiModalEmbeddingModel": _HfExamplesInfo(
        os.path.join(models_path_prefix, "google/gemma-3-4b-it")
    ),
    "TransformersMultiModalForSequenceClassification": _HfExamplesInfo(
        os.path.join(models_path_prefix, "google/gemma-3-4b-it")
    ),
}

1119
1120
1121
# Union of the per-task example tables. With ``**`` unpacking, a later table
# overrides an earlier one on duplicate keys. For example, "Phi3VForCausalLM",
# "LlavaNextForConditionalGeneration" and "Qwen2VLForConditionalGeneration"
# are also listed in the embedding table, and here they resolve to their
# _MULTIMODAL_EXAMPLE_MODELS entries.
# _AUTOMATIC_CONVERTED_MODELS is intentionally not included.
_EXAMPLE_MODELS = {
    **_TEXT_GENERATION_EXAMPLE_MODELS,
    **_EMBEDDING_EXAMPLE_MODELS,
    **_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS,
    **_MULTIMODAL_EXAMPLE_MODELS,
    **_SPECULATIVE_DECODING_EXAMPLE_MODELS,
    **_TRANSFORMERS_BACKEND_MODELS,
}


class HfExampleModels:
    """Registry mapping architecture names to their example-model metadata.

    Supports lookup by architecture name (:meth:`get_hf_info`) and reverse
    lookup by checkpoint id/path (:meth:`find_hf_info`).
    """

    def __init__(self, hf_models: Mapping[str, _HfExamplesInfo]) -> None:
        super().__init__()

        # Kept by reference (not copied), so later mutation of the mapping is
        # visible through this registry.
        self.hf_models = hf_models

    def get_supported_archs(self) -> Set[str]:
        """Return a keys view over every registered architecture name."""
        return self.hf_models.keys()

    def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
        """Return the example info registered for ``model_arch``.

        Raises:
            ValueError: If no example model is registered for ``model_arch``.
        """
        try:
            return self.hf_models[model_arch]
        except KeyError:
            # Surface a ValueError with guidance; suppress the KeyError context.
            raise ValueError(
                f"No example model defined for {model_arch}; please update this file."
            ) from None

    def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
        """Return the entry whose checkpoint equals ``model_id``.

        All ``default`` checkpoints are searched before any ``extras`` values.
        A match on a default therefore wins over an extras match in another
        entry.

        Raises:
            ValueError: If no entry references ``model_id``.
        """
        for info in self.hf_models.values():
            if info.default == model_id:
                return info

        # Fallback to extras
        for info in self.hf_models.values():
            if model_id in info.extras.values():
                return info

        raise ValueError(
            f"No example model defined for {model_id}; please update this file."
        )


# Registry over the merged per-task tables in _EXAMPLE_MODELS.
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
# Separate registry for the automatically converted (as_seq_cls_model)
# architectures, which are not part of _EXAMPLE_MODELS.
AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS)