config.py 44.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import fnmatch
5
import json
6
import os
7
import time
8
from collections.abc import Callable
9
from dataclasses import asdict
10
from functools import cache, partial
11
from importlib.metadata import version
12
from pathlib import Path
13
from typing import Any, Literal, TypeAlias, TypeVar
Jasmond L's avatar
Jasmond L committed
14

Joe Runde's avatar
Joe Runde committed
15
import huggingface_hub
16
17
18
19
20
from huggingface_hub import (
    get_safetensors_metadata,
    hf_hub_download,
    try_to_load_from_cache,
)
21
from huggingface_hub import list_repo_files as hf_list_repo_files
22
23
24
25
26
27
28
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    LocalEntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
)
29
from packaging.version import Version
30
from transformers import GenerationConfig, PretrainedConfig
31
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
32
from transformers.models.auto.image_processing_auto import get_image_processor_config
33
34
35
36
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
37
from transformers.models.auto.tokenization_auto import get_tokenizer_config
38
from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
39

40
from vllm import envs
41
from vllm.logger import init_logger
42
from vllm.transformers_utils.config_parser_base import ConfigParserBase
43
44
45
46
from vllm.transformers_utils.utils import (
    check_gguf_file,
    parse_safetensors_file_metadata,
)
47

48
if envs.VLLM_USE_MODELSCOPE:
49
50
51
    from modelscope import AutoConfig
else:
    from transformers import AutoConfig
52

53
54
# Name of the config file shipped by checkpoints in Mistral's native format;
# get_config() probes for it before falling back to the HF config.json.
MISTRAL_CONFIG_NAME = "params.json"

# Module-level logger shared by every helper below.
logger = init_logger(__name__)

57

58
def _get_hf_token() -> str | None:
59
60
61
    """
    Get the HuggingFace token from environment variable.

62
    Returns None if the token is not set, is an empty string,
63
64
65
66
    or contains only whitespace.
    This follows the same pattern as huggingface_hub library which
    treats empty string tokens as None to avoid authentication errors.
    """
67
    token = os.getenv("HF_TOKEN")
68
69
70
71
72
    if token and token.strip():
        return token
    return None


73
74
class LazyConfigDict(dict):
    """Dict whose values are config classes or names of config classes.

    A value that is already a class is returned as-is. A value that is not a
    class is treated as the name of an attribute of
    ``vllm.transformers_utils.configs`` and resolved on item access. This
    defers importing that package until a by-name entry is actually looked up.
    """

    def __getitem__(self, key):
        entry = super().__getitem__(key)
        if not isinstance(entry, type):
            # Deferred import: only needed for entries stored by name.
            import vllm.transformers_utils.configs as configs

            entry = getattr(configs, entry)
        return entry
81
82
83


# model_type -> vLLM config class used by HFConfigParser instead of AutoConfig.
# Values are class names inside vllm.transformers_utils.configs and are
# resolved on lookup by LazyConfigDict.
_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
    {
        "afmoe": "AfmoeConfig",
        "chatglm": "ChatGLMConfig",
        "deepseek_vl_v2": "DeepseekVLV2Config",
        "deepseek_v32": "DeepseekV3Config",
        "flex_olmo": "FlexOlmoConfig",
        "hunyuan_vl": "HunYuanVLConfig",
        "kimi_linear": "KimiLinearConfig",
        "kimi_vl": "KimiVLConfig",
        "RefinedWeb": "RWConfig",  # For tiiuae/falcon-40b(-instruct)
        "RefinedWebModel": "RWConfig",  # For tiiuae/falcon-7b(-instruct)
        "jais": "JAISConfig",
        "mlp_speculator": "MLPSpeculatorConfig",
        "medusa": "MedusaConfig",
        "midashenglm": "MiDashengLMConfig",
        "eagle": "EAGLEConfig",
        "speculators": "SpeculatorsConfig",
        "nemotron": "NemotronConfig",
        "olmo3": "Olmo3Config",
        "ovis": "OvisConfig",
        "ultravox": "UltravoxConfig",
        "step3_vl": "Step3VLConfig",
        "step3_text": "Step3TextConfig",
        "qwen3_next": "Qwen3NextConfig",
        "lfm2_moe": "Lfm2MoeConfig",
    }
)
109

110
111
112
113
_CONFIG_ATTRS_MAPPING: dict[str, str] = {
    "llm_config": "text_config",
}

114
_AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = {
115
    "internvl_chat": {"has_no_defaults_at_init": True},
116
    "Llama_Nemotron_Nano_VL": {"attn_implementation": "eager"},
117
    "NVLM_D": {"has_no_defaults_at_init": True},
118
119
}

120

121
class HFConfigParser(ConfigParserBase):
    """Parser for checkpoints that ship a Hugging Face ``config.json``."""

    def parse(
        self,
        model: str | Path,
        trust_remote_code: bool,
        revision: str | None = None,
        code_revision: str | None = None,
        **kwargs,
    ) -> tuple[dict, PretrainedConfig]:
        """Load the raw config dict and the config object for ``model``.

        Model types listed in ``_CONFIG_REGISTRY`` are instantiated with the
        vLLM-provided class. A dict without ``model_type`` but with a
        ``speculators_config`` is treated as model type ``"speculators"``.
        Every other model type goes through ``AutoConfig``.

        Returns:
            ``(config_dict, config)``: the dict returned by
            ``PretrainedConfig.get_config_dict`` and the config object after
            ``_maybe_remap_hf_config_attrs``.

        Raises:
            RuntimeError: when ``AutoConfig`` raises a ``ValueError`` asking to
                execute the repo's configuration file and ``trust_remote_code``
                is ``False``. Any other ``ValueError`` is re-raised unchanged.
        """
        # Honour HF offline mode for every lookup performed below.
        kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
        config_dict, _ = PretrainedConfig.get_config_dict(
            model,
            revision=revision,
            code_revision=code_revision,
            token=_get_hf_token(),
            **kwargs,
        )

        model_type = config_dict.get("model_type")
        if model_type is None and config_dict.get("speculators_config") is not None:
            model_type = "speculators"

        if model_type in _CONFIG_REGISTRY:
            # Use custom model class if it's in our registry.
            registered_cls = _CONFIG_REGISTRY[model_type]
            config = registered_cls.from_pretrained(
                model,
                revision=revision,
                code_revision=code_revision,
                token=_get_hf_token(),
                **kwargs,
            )
        else:
            try:
                kwargs = _maybe_update_auto_config_kwargs(kwargs, model_type=model_type)
                config = AutoConfig.from_pretrained(
                    model,
                    trust_remote_code=trust_remote_code,
                    revision=revision,
                    code_revision=code_revision,
                    token=_get_hf_token(),
                    **kwargs,
                )
            except ValueError as e:
                wants_remote_code = (
                    "requires you to execute the configuration file" in str(e)
                )
                if trust_remote_code or not wants_remote_code:
                    raise
                raise RuntimeError(
                    "Failed to load the model config. If the model "
                    "is a custom model not yet available in the "
                    "HuggingFace transformers library, consider setting "
                    "`trust_remote_code=True` in LLM or using the "
                    "`--trust-remote-code` flag in the CLI."
                ) from e

        return config_dict, _maybe_remap_hf_config_attrs(config)


class MistralConfigParser(ConfigParserBase):
    """Parser for checkpoints in Mistral format (a ``params.json`` config)."""

    def parse(
        self,
        model: str | Path,
        trust_remote_code: bool,
        revision: str | None = None,
        code_revision: str | None = None,
        **kwargs,
    ) -> tuple[dict, PretrainedConfig]:
        """Build an HF-compatible config from ``params.json``.

        A missing ``max_position_embeddings`` is filled in via
        ``_maybe_retrieve_max_pos_from_hf``. Other missing fields are taken
        from the repo's HF config when one can be loaded. A list-valued
        ``sliding_window`` is converted to HF ``layer_types`` plus a single
        int window.

        Returns:
            ``(config_dict, config)``: the (possibly updated) ``params.json``
            dict and the adapted config object.
        """
        config_dict = _download_mistral_config_file(model, revision)

        if config_dict.get("max_position_embeddings") is None:
            config_dict["max_position_embeddings"] = _maybe_retrieve_max_pos_from_hf(
                model, revision, **kwargs
            )

        from vllm.transformers_utils.configs.mistral import adapt_config_dict

        # An HF config, if present, supplies defaults for missing fields.
        try:
            hf_defaults, _ = PretrainedConfig.get_config_dict(
                model,
                revision=revision,
                code_revision=code_revision,
                token=_get_hf_token(),
                **kwargs,
            )
        except OSError:  # Not found
            hf_defaults = {}

        config = adapt_config_dict(config_dict, defaults=hf_defaults)

        sliding_window = getattr(config, "sliding_window", None)
        if sliding_window and isinstance(sliding_window, list):
            # The list is a per-layer pattern tiled num_hidden_layers //
            # len(pattern) times; None entries mean full attention.
            # NOTE(review): if num_hidden_layers is not a multiple of the
            # pattern length, layer_types ends up shorter than the layer count.
            repeats = config.num_hidden_layers // len(sliding_window)
            config.layer_types = [
                "sliding_attention" if window is not None else "full_attention"
                for window in sliding_window * repeats
            ]
            # HF expects a single int: keep the first truthy window size.
            config.sliding_window = next(
                (window for window in sliding_window if window), None
            )

        return config_dict, config


# Built-in parsers keyed by config format; register_config_parser() adds to
# (or overwrites entries in) this table at runtime.
_CONFIG_FORMAT_TO_CONFIG_PARSER: dict[str, type[ConfigParserBase]] = dict(
    hf=HFConfigParser,
    mistral=MistralConfigParser,
)

# Built-in values for `config_format`. "auto" lets get_config() choose
# between "mistral" and "hf" based on which config file exists.
ConfigFormat = Literal["auto", "hf", "mistral"]


def get_config_parser(config_format: str) -> ConfigParserBase:
    """Return a new instance of the parser registered for ``config_format``.

    Raises:
        ValueError: if no parser is registered under ``config_format``.
    """
    parser_cls = _CONFIG_FORMAT_TO_CONFIG_PARSER.get(config_format)
    if parser_cls is None:
        raise ValueError(f"Unknown config format `{config_format}`.")
    return parser_cls()


def register_config_parser(config_format: str):
    """Register a customized vllm config parser.

    When a config format is not supported by vllm, you can register a customized
    config parser to support it. Registering a format name that already exists
    replaces the previous parser and logs a warning.

    Args:
        config_format (str): The config parser format name.

    Returns:
        A class decorator that registers the decorated class and returns it
        unchanged.

    Raises:
        ValueError: raised by the returned decorator when the decorated object
            is not a subclass of ``ConfigParserBase``. In that case the
            registry is left untouched.

    Examples:

        >>> from vllm.transformers_utils.config import (
        ...     get_config_parser,
        ...     register_config_parser,
        ... )
        >>> from vllm.transformers_utils.config_parser_base import ConfigParserBase
        >>>
        >>> @register_config_parser("custom_config_parser")
        ... class CustomConfigParser(ConfigParserBase):
        ...     def parse(
        ...         self,
        ...         model: str | Path,
        ...         trust_remote_code: bool,
        ...         revision: str | None = None,
        ...         code_revision: str | None = None,
        ...         **kwargs,
        ...     ) -> tuple[dict, PretrainedConfig]:
        ...         raise NotImplementedError
        >>>
        >>> type(get_config_parser("custom_config_parser"))
        <class 'CustomConfigParser'>
    """  # noqa: E501

    def _wrapper(config_parser_cls):
        # Validate first so that a rejected object neither emits a misleading
        # "will be overwritten" warning nor reaches the registry. The explicit
        # type check turns a TypeError from issubclass() on a non-class into
        # the documented ValueError.
        if not (
            isinstance(config_parser_cls, type)
            and issubclass(config_parser_cls, ConfigParserBase)
        ):
            raise ValueError(
                "The config parser must be a subclass of `ConfigParserBase`."
            )
        if config_format in _CONFIG_FORMAT_TO_CONFIG_PARSER:
            logger.warning(
                "Config format `%s` is already registered, and will be "
                "overwritten by the new parser class `%s`.",
                config_format,
                config_parser_cls,
            )
        _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format] = config_parser_cls
        logger.info(
            "Registered config parser `%s` with config format `%s`",
            config_parser_cls,
            config_format,
        )
        return config_parser_cls

    return _wrapper
306
307


308
309
310
311
312
313
314
315
316
_R = TypeVar("_R")


def with_retry(
    func: Callable[[], _R],
    log_msg: str,
    max_retries: int = 2,
    retry_delay: int = 2,
) -> _R:
    """Call ``func`` and return its result, retrying on any ``Exception``.

    ``func`` is attempted at most ``max_retries`` times. Every failure is logged
    at error level, prefixed with ``log_msg``. Between attempts the call sleeps
    for ``retry_delay`` seconds, and the delay doubles after each retry. The
    exception raised by the last attempt propagates to the caller.

    Raises:
        AssertionError: if ``max_retries`` is not positive, so ``func`` is
            never called.
    """
    final_attempt = max_retries - 1
    delay = retry_delay
    for attempt in range(max_retries):
        try:
            return func()
        except Exception as exc:
            if attempt == final_attempt:
                logger.error("%s: %s", log_msg, exc)
                raise
            logger.error(
                "%s: %s, retrying %d of %d", log_msg, exc, attempt + 1, max_retries
            )
            time.sleep(delay)
            delay *= 2

    raise AssertionError("Should not be reached")

332
333
334
335
336
337

# @cache doesn't cache exceptions: a failed lookup is retried on the next call.
@cache
def list_repo_files(
    repo_id: str,
    *,
    revision: str | None = None,
    repo_type: str | None = None,
    token: str | bool | None = None,
) -> list[str]:
    """List the files of ``repo_id`` as paths relative to the repo root.

    If ``repo_id`` is an existing local path, its directory tree is walked.
    Otherwise the files are listed remotely:
    - through ModelScope (token from ``MODELSCOPE_API_TOKEN``) when
      ``VLLM_USE_MODELSCOPE`` is set;
    - through the Hugging Face Hub otherwise.

    If HF offline mode is enabled, an empty list is returned. The lookup is
    run through :func:`with_retry`. Results are memoized per argument
    combination, so callers share (and must not mutate) the returned list.
    """

    def _lookup() -> list[str]:
        root = Path(repo_id)
        if root.exists():
            # Local model: list the files on disk directly.
            return [
                str(path.relative_to(root))
                for path in root.rglob("*")
                if path.is_file()
            ]
        try:
            if not envs.VLLM_USE_MODELSCOPE:
                return hf_list_repo_files(
                    repo_id, revision=revision, repo_type=repo_type, token=token
                )
            from vllm.transformers_utils.utils import modelscope_list_repo_files

            return modelscope_list_repo_files(
                repo_id,
                revision=revision,
                token=os.getenv("MODELSCOPE_API_TOKEN", None),
            )
        except huggingface_hub.errors.OfflineModeIsEnabled:
            # Don't raise in offline mode; all we know is that we don't have
            # this file cached.
            return []

    return with_retry(_lookup, "Error retrieving file list")


372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
def list_filtered_repo_files(
    model_name_or_path: str,
    allow_patterns: list[str],
    revision: str | None = None,
    repo_type: str | None = None,
    token: str | bool | None = None,
) -> list[str]:
    """Return the repo files whose base name matches any of ``allow_patterns``.

    Patterns are ``fnmatch`` globs tested against each file's base name, so
    directory components are ignored. Results are grouped by pattern in the
    order of ``allow_patterns``. A file that matches several patterns appears
    once per matching pattern.

    Listing the repo is best-effort: if :func:`list_repo_files` raises, the
    failure is logged and an empty list is returned.
    """
    try:
        all_files = list_repo_files(
            repo_id=model_name_or_path,
            revision=revision,
            token=token,
            repo_type=repo_type,
        )
    except Exception:
        logger.error(
            "Error retrieving file list. Please ensure your `model_name_or_path`, "
            "`repo_type`, `token` and `revision` arguments are correctly set. "
            "Returning an empty list."
        )
        return []

    # Compute each base name once instead of once per (pattern, file) pair.
    base_names = [os.path.basename(file) for file in all_files]
    return [
        file
        for pattern in allow_patterns
        for file, base_name in zip(all_files, base_names)
        if fnmatch.fnmatch(base_name, pattern)
    ]


407
408
409
410
def file_exists(
    repo_id: str,
    file_name: str,
    *,
    repo_type: str | None = None,
    revision: str | None = None,
    token: str | bool | None = None,
) -> bool:
    """Return whether ``file_name`` appears in the file listing of ``repo_id``.

    The name is compared exactly against the repo-relative paths produced by
    :func:`list_repo_files`, so files in subdirectories must include them.
    """
    listing = list_repo_files(
        repo_id, repo_type=repo_type, revision=revision, token=token
    )
    return file_name in listing


# In offline mode the result can be a false negative
422
def file_or_path_exists(
423
    model: str | Path, config_name: str, revision: str | None
424
) -> bool:
425
426
    if (local_path := Path(model)).exists():
        return (local_path / config_name).is_file()
427

Joe Runde's avatar
Joe Runde committed
428
    # Offline mode support: Check if config file is cached already
429
430
431
    cached_filepath = try_to_load_from_cache(
        repo_id=model, filename=config_name, revision=revision
    )
Joe Runde's avatar
Joe Runde committed
432
433
434
435
436
437
    if isinstance(cached_filepath, str):
        # The config file exists in cache- we can continue trying to load
        return True

    # NB: file_exists will only check for the existence of the config file on
    # hf_hub. This will fail in offline mode.
438
439

    # Call HF to check if the file exists
440
441
442
    return file_exists(
        str(model), config_name, revision=revision, token=_get_hf_token()
    )
443
444


445
446
447
448
449
450
451
452
453
454
def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> None:
    """Make sure ``config.rope_parameters`` has a ``rope_theta`` entry.

    Some models use RoPE without declaring ``rope_theta``. If
    ``rope_parameters`` is missing or ``None``, it is first set to
    ``{"rope_type": "default"}``. ``default_theta`` is then stored under
    ``"rope_theta"`` only if no value is present yet.
    """
    if getattr(config, "rope_parameters", None) is None:
        config.rope_parameters = {"rope_type": "default"}
    config.rope_parameters.setdefault("rope_theta", default_theta)


def patch_rope_parameters(config: PretrainedConfig) -> None:
    """Provide backwards compatibility for RoPE.

    Makes ``config.rope_parameters`` hold the RoPE settings:

    * Transformers v5+: the existing ``rope_parameters`` attribute is used.
    * Transformers v4:
      - if ``rope_parameters`` already exists, the config is assumed to be
        patched and nothing is done;
      - otherwise the attribute is built from ``rope_scaling``, ``rope_theta``
        and ``original_max_position_embeddings`` and written back.

    The resulting dict is then normalized in place by
    :func:`patch_rope_parameters_dict`. When every key is a layer type from
    ``ALLOWED_LAYER_TYPES``, each per-layer-type dict is normalized instead.
    """
    on_transformers_v5 = Version(version("transformers")) >= Version("5.0.0.dev0")

    if on_transformers_v5:
        from transformers.modeling_rope_utils import RopeParameters

        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = getattr(
            config, "rope_parameters", None
        )
    elif hasattr(config, "rope_parameters"):
        # Transformers v4 and this config has already been patched.
        return
    else:
        # Transformers v4: fold the legacy fields into rope_parameters.
        # NB: rope_scaling is reused rather than copied. Unless it is replaced
        # by the default dict below, config.rope_scaling and
        # config.rope_parameters end up sharing one dict.
        rope_theta: float | None = getattr(config, "rope_theta", None)
        rope_parameters = getattr(config, "rope_scaling", None)
        if rope_theta is not None:
            if not rope_parameters:
                rope_parameters = {"rope_type": "default"}
            rope_parameters["rope_theta"] = rope_theta
        if rope_parameters:
            ompe = getattr(config, "original_max_position_embeddings", None)
            if ompe:
                rope_parameters["original_max_position_embeddings"] = ompe
        config.rope_parameters = rope_parameters

    if rope_parameters is None:
        # No RoPE parameters to patch.
        return

    if set(rope_parameters).issubset(ALLOWED_LAYER_TYPES):
        # Interleaved sliding-attention models nest one dict per layer type.
        dicts_to_patch = rope_parameters.values()
    else:
        dicts_to_patch = [rope_parameters]
    for params in dicts_to_patch:
        patch_rope_parameters_dict(params)
494
495


496
497
498
499
def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None:
    """Normalize one RoPE parameter dict in place.

    - A dict with both the modern ``rope_type`` and the legacy ``type`` key
      must give the same value for both.
    - If only ``type`` is set, its value is copied to ``rope_type``. The
      ``type`` key itself is kept.
    - ``rope_type`` must be present after that step.
    - Legacy rope types are renamed: ``"su"`` becomes ``"longrope"``, and
      ``"mrope"`` becomes ``"default"``. The ``"mrope"`` case requires an
      ``mrope_section`` entry.

    Raises:
        ValueError: if ``rope_type`` and ``type`` conflict, if no
            ``rope_type`` can be determined, or if ``rope_type == "mrope"``
            without ``mrope_section``.
    """
    if "rope_type" in rope_parameters and "type" in rope_parameters:
        rope_type = rope_parameters["rope_type"]
        rope_type_legacy = rope_parameters["type"]
        if rope_type != rope_type_legacy:
            raise ValueError(
                f"Found conflicts between 'rope_type={rope_type}' (modern "
                f"field) and 'type={rope_type_legacy}' (legacy field). "
                "You should only specify one of them."
            )

    if "rope_type" not in rope_parameters and "type" in rope_parameters:
        rope_parameters["rope_type"] = rope_parameters["type"]
        logger.info("Replacing legacy 'type' key with 'rope_type'")

    if "rope_type" not in rope_parameters:
        raise ValueError("rope_parameters should have a 'rope_type' key")

    if rope_parameters["rope_type"] == "su":
        rope_parameters["rope_type"] = "longrope"
        logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
    elif rope_parameters["rope_type"] == "mrope":
        # Explicit check instead of `assert`, which is stripped under -O.
        if "mrope_section" not in rope_parameters:
            raise ValueError(
                "rope_parameters with rope_type 'mrope' must define 'mrope_section'"
            )
        rope_parameters["rope_type"] = "default"
        logger.warning("Replacing legacy rope_type 'mrope' with 'default'")


523
def _uses_mrope(config: PretrainedConfig) -> bool:
    """Return True when ``config.rope_parameters`` has an ``mrope_section`` key."""
    rope_parameters = getattr(config, "rope_parameters", None)
    return rope_parameters is not None and "mrope_section" in rope_parameters
529
530


531
532
def uses_mrope(config: PretrainedConfig) -> bool:
    """Detect if the model with this config uses M-ROPE.

    Checks, in order and stopping at the first hit:
    1. the config itself;
    2. its text config;
    3. a nested thinker text config, via :func:`thinker_uses_mrope`.
    """
    if _uses_mrope(config):
        return True
    if _uses_mrope(config.get_text_config()):
        return True
    return thinker_uses_mrope(config)
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552


def thinker_uses_mrope(config: PretrainedConfig) -> bool:
    """Detect if ``config.thinker_config.text_config`` exists and uses M-ROPE."""
    thinker = getattr(config, "thinker_config", None)
    # getattr on None with a default simply yields None.
    thinker_text_config = getattr(thinker, "text_config", None)
    if thinker_text_config is None:
        return False
    return uses_mrope(thinker_text_config)


553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
def uses_xdrope_dim(config: PretrainedConfig) -> int:
    """Return the number of XD-RoPE sections, or 0 if XD-RoPE is not used.

    ``xdrope_section`` is looked up first as a config attribute. If that is
    not a list, the ``xdrope_section`` key of a dict-valued ``rope_scaling``
    is tried. Only list values count.
    """
    section = getattr(config, "xdrope_section", None)
    if not isinstance(section, list):
        rope_scaling = getattr(config, "rope_scaling", None)
        section = (
            rope_scaling.get("xdrope_section")
            if isinstance(rope_scaling, dict)
            else None
        )
    return len(section) if isinstance(section, list) else 0


570
571
572
def is_encoder_decoder(config: PretrainedConfig) -> bool:
    """Detect if the model with this config is used as an encoder/decoder.

    Returns the config's own ``is_encoder_decoder`` value if truthy. Otherwise
    returns its text config's value, with ``False`` as the default.
    """
    own_flag = getattr(config, "is_encoder_decoder", False)
    if own_flag:
        return own_flag
    return getattr(config.get_text_config(), "is_encoder_decoder", False)
577
578


579
580
581
582
583
584
def is_interleaved(config: PretrainedConfig) -> bool:
    """
    Detect if the model with this config is used with interleaved attention.

    True when the text config's ``layer_types`` holds more than one distinct
    layer type.
    """
    layer_types = getattr(config.get_text_config(), "layer_types", None)
    return bool(layer_types) and len(set(layer_types)) > 1


589
590
591
592
593
594
595
596
597
def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str):
    """
    Merge the per-model-type AutoConfig overrides into ``kwargs``.

    ``kwargs`` is updated in place with
    ``_AUTO_CONFIG_KWARGS_OVERRIDES[model_type]``, if such an entry exists,
    and is then returned.
    """
    overrides = _AUTO_CONFIG_KWARGS_OVERRIDES.get(model_type)
    if overrides is not None:
        kwargs.update(overrides)
    return kwargs


598
599
600
601
602
603
def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
    """Remap config attributes to match the expected names.

    For each ``old -> new`` pair in ``_CONFIG_ATTRS_MAPPING``: if ``config``
    has ``old`` but not ``new``, ``new`` is set to the value of ``old`` through
    ``config.update``. The old attribute is left in place. The config is
    modified in place and returned.
    """
    for old_attr, new_attr in _CONFIG_ATTRS_MAPPING.items():
        if hasattr(config, old_attr) and not hasattr(config, new_attr):
            config.update({new_attr: getattr(config, old_attr)})
            # Only report a remap when one actually happened.
            logger.debug("Remapped config attribute '%s' to '%s'", old_attr, new_attr)
    return config


608
def maybe_override_with_speculators(
    model: str,
    tokenizer: str,
    trust_remote_code: bool,
    revision: str | None = None,
    vllm_speculative_config: dict[str, Any] | None = None,
    **kwargs,
) -> tuple[str, str, dict[str, Any] | None]:
    """
    Resolve model configuration when speculators are detected.

    Reads the config of ``model``. For a GGUF file, the config is read from
    the file's directory. If the config has no ``speculators_config``, the
    inputs are returned unchanged. Otherwise ``model`` becomes the draft model
    of a speculative config built from the checkpoint, and the verifier model
    named in the config replaces both ``model`` and ``tokenizer``.

    Args:
        model: Model name or path
        tokenizer: Tokenizer name or path
        trust_remote_code: Whether to trust remote code
        revision: Model revision
        vllm_speculative_config: Existing vLLM speculative config

    Returns:
        Tuple of (resolved_model, resolved_tokenizer, speculative_config)
    """
    config_source: str | Path = model
    if check_gguf_file(model):
        gguf_path = Path(model)
        kwargs["gguf_file"] = gguf_path.name
        config_source = gguf_path.parent
    kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE

    config_dict, _ = PretrainedConfig.get_config_dict(
        config_source,
        revision=revision,
        trust_remote_code=trust_remote_code,
        token=_get_hf_token(),
        **kwargs,
    )

    speculators_config = config_dict.get("speculators_config")
    if speculators_config is None:
        # Not a speculators checkpoint: nothing to override.
        return model, tokenizer, vllm_speculative_config

    from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig

    speculative_config = SpeculatorsConfig.extract_vllm_speculative_config(
        config_dict=config_dict
    )
    # The speculators checkpoint itself is the draft model.
    speculative_config["model"] = model

    # The verifier named in the config becomes the model and the tokenizer.
    verifier_model = speculators_config["verifier"]["name_or_path"]
    return verifier_model, verifier_model, speculative_config
667
668


669
def _detect_config_format(
    model: str | Path, revision: str | None, is_gguf: bool
) -> str:
    """Choose "mistral" or "hf" for ``config_format="auto"``.

    Raises:
        ValueError: if no known config file is found or the lookup itself fails
            (e.g. an invalid repo ID). The underlying error is chained.
    """
    try:
        # First check for Mistral to avoid defaulting to
        # Transformers implementation.
        if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
            return "mistral"
        if is_gguf or file_or_path_exists(model, HF_CONFIG_NAME, revision=revision):
            return "hf"
        raise ValueError(
            "Could not detect config format for no config file found. "
            "With config_format 'auto', ensure your model has either "
            "config.json (HF format) or params.json (Mistral format). "
            "Otherwise please specify your_custom_config_format "
            "in engine args for customized config parser."
        )
    except Exception as e:
        error_message = (
            "Invalid repository ID or local directory specified:"
            " '{model}'.\nPlease verify the following requirements:\n"
            "1. Provide a valid Hugging Face repository ID.\n"
            "2. Specify a local directory that contains a recognized "
            "configuration file.\n"
            "   - For Hugging Face models: ensure the presence of a "
            "'config.json'.\n"
            "   - For Mistral models: ensure the presence of a "
            "'params.json'.\n"
            "3. For GGUF: pass the local path of the GGUF checkpoint.\n"
            "   Loading GGUF from a remote repo directly is not yet "
            "supported.\n"
        ).format(model=model)

        raise ValueError(error_message) from e


def _maybe_enable_deep_gemm_ue8m0(quantization_config: dict[str, Any]) -> None:
    """Enable DeepGEMM UE8M0 when the quantization config requests it.

    Acts only when ``scale_fmt`` is ``"ue8m0"``:
    - if ``VLLM_USE_DEEP_GEMM_E8M0`` is unset, it is set to ``"1"``;
    - if the user explicitly disabled it, a warning is logged instead.
    """
    scale_fmt = quantization_config.get("scale_fmt", None)
    if scale_fmt != "ue8m0":
        return
    if not envs.is_set("VLLM_USE_DEEP_GEMM_E8M0"):
        os.environ["VLLM_USE_DEEP_GEMM_E8M0"] = "1"
        logger.info_once(
            (
                "Detected quantization_config.scale_fmt=%s; "
                "enabling UE8M0 for DeepGEMM."
            ),
            scale_fmt,
        )
    elif not envs.VLLM_USE_DEEP_GEMM_E8M0:
        logger.warning_once(
            (
                "Model config requests UE8M0 "
                "(quantization_config.scale_fmt=%s), but "
                "VLLM_USE_DEEP_GEMM_E8M0=0 is set; "
                "UE8M0 for DeepGEMM disabled."
            ),
            scale_fmt,
        )


def _patch_all_rope_parameters(config: PretrainedConfig) -> None:
    """Patch RoPE parameters on the config, its text config and sub-configs."""
    patch_rope_parameters(config)
    patch_rope_parameters(config.get_text_config())
    SubConfigs: TypeAlias = dict[str, PretrainedConfig]
    sub_configs: SubConfigs | None = getattr(config, "sub_configs", None)
    for sub_config_name in sub_configs or ():
        sub_config = getattr(config, sub_config_name)
        # Skip unset sub-configs. Under Transformers v4, patch_rope_parameters
        # would otherwise try to set `rope_parameters` on None and crash.
        if sub_config is not None:
            patch_rope_parameters(sub_config)


def get_config(
    model: str | Path,
    trust_remote_code: bool,
    revision: str | None = None,
    code_revision: str | None = None,
    config_format: str | ConfigFormat = "auto",
    hf_overrides_kw: dict[str, Any] | None = None,
    hf_overrides_fn: Callable[[PretrainedConfig], PretrainedConfig] | None = None,
    **kwargs,
) -> PretrainedConfig:
    """Load and post-process the config for ``model``.

    Steps:
    1. A GGUF file path is split: the config is read from the file's parent
       directory, and the file name is passed as ``gguf_file``.
    2. ``config_format="auto"`` is resolved to ``"mistral"`` (``params.json``
       present) or ``"hf"`` (``config.json`` present, or GGUF). The config is
       then parsed with the registered parser for that format.
    3. ``architectures`` is filled from Transformers' auto-model mappings when
       it is missing (always for GGUF).
    4. The quantization config is attached. It is read from the config, or
       else from ``hf_quant_config.json``. DeepGEMM UE8M0 may be enabled
       based on it.
    5. ``hf_overrides_kw`` is applied, then ``hf_overrides_fn``.
    6. RoPE parameters are normalized on the config, its text config and its
       sub-configs.

    Raises:
        ValueError: if the format cannot be auto-detected, or
            ``config_format`` names an unregistered parser.
        RuntimeError: for a GGUF model whose ``model_type`` has no causal-LM
            architecture mapping.
    """
    # Separate model folder from file path for GGUF models.
    is_gguf = check_gguf_file(model)
    if is_gguf:
        kwargs["gguf_file"] = Path(model).name
        model = Path(model).parent

    if config_format == "auto":
        config_format = _detect_config_format(model, revision, is_gguf)

    config_parser = get_config_parser(config_format)
    config_dict, config = config_parser.parse(
        model,
        trust_remote_code=trust_remote_code,
        revision=revision,
        code_revision=code_revision,
        **kwargs,
    )

    # Special architecture mapping check for GGUF models.
    if is_gguf:
        if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
            raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
        model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
        config.update({"architectures": [model_type]})

    # Architecture mapping for models without explicit architectures field.
    if not config.architectures:
        if config.model_type not in MODEL_MAPPING_NAMES:
            logger.warning(
                "Model config does not have a top-level 'architectures' field: "
                "expecting `hf_overrides={'architectures': ['...']}` to be passed "
                "in engine args."
            )
        else:
            model_type = MODEL_MAPPING_NAMES[config.model_type]
            config.update({"architectures": [model_type]})

    # ModelOpt 0.31.0 and after saves the quantization config in the model
    # config file.
    quantization_config = config_dict.get("quantization_config", None)

    # ModelOpt 0.29.0 and before saves the quantization config in a separate
    # "hf_quant_config.json" in the same directory as the model config file.
    if quantization_config is None and file_or_path_exists(
        model, "hf_quant_config.json", revision
    ):
        quantization_config = get_hf_file_to_dict(
            "hf_quant_config.json", model, revision
        )

    if quantization_config is not None:
        config.quantization_config = quantization_config
        _maybe_enable_deep_gemm_ue8m0(quantization_config)

    if hf_overrides_kw:
        logger.debug("Overriding HF config with %s", hf_overrides_kw)
        config.update(hf_overrides_kw)
    if hf_overrides_fn:
        logger.debug("Overriding HF config with %s", hf_overrides_fn)
        config = hf_overrides_fn(config)

    # Exhaustively patch RoPE parameters everywhere they might be.
    _patch_all_rope_parameters(config)

    if trust_remote_code:
        maybe_register_config_serialize_by_value()

    return config
808
809


810
def try_get_local_file(
811
812
    model: str | Path, file_name: str, revision: str | None = "main"
) -> Path | None:
813
814
815
816
817
    file_path = Path(model) / file_name
    if file_path.is_file():
        return file_path
    else:
        try:
818
819
820
            cached_filepath = try_to_load_from_cache(
                repo_id=model, filename=file_name, revision=revision
            )
821
822
            if isinstance(cached_filepath, str):
                return Path(cached_filepath)
823
        except ValueError:
824
825
826
827
            ...
    return None


def get_hf_file_to_dict(
    file_name: str, model: str | Path, revision: str | None = "main"
):
    """
    Load a JSON file from a local model directory or the Hugging Face Hub.

    The file is looked up locally first (model directory, then HF cache);
    if it is not found it is downloaded with ``hf_hub_download``.

    Parameters:
    - file_name (str): The name of the file to load.
    - model (str | Path): Local model directory or Hugging Face repo id.
    - revision (str | None): The specific version of the model.

    Returns:
    - The parsed JSON content (a dict or a list, depending on the file),
      or None if the file is missing, the Hub is unreachable, or offline
      mode prevents the download.
    """
    file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)

    if file_path is None:
        try:
            hf_hub_file = hf_hub_download(
                model, file_name, revision=revision, token=_get_hf_token()
            )
        except huggingface_hub.errors.OfflineModeIsEnabled:
            return None
        except (
            RepositoryNotFoundError,
            RevisionNotFoundError,
            EntryNotFoundError,
            LocalEntryNotFoundError,
        ) as e:
            # The message needs a placeholder for `e`; otherwise the logging
            # module reports a string-formatting error instead of the message.
            logger.debug("File or repository not found in hf_hub_download: %s", e)
            return None
        except HfHubHTTPError as e:
            logger.warning(
                "Cannot connect to Hugging Face Hub. Skipping file download for '%s':",
                file_name,
                exc_info=e,
            )
            return None
        file_path = Path(hf_hub_file)

    # At this point file_path is always set; it may still not be a file.
    if file_path.is_file():
        # JSON is UTF-8 by specification; don't depend on the locale.
        with open(file_path, encoding="utf-8") as file:
            return json.load(file)

    return None


@cache
def get_pooling_config(model: str, revision: str | None = "main") -> dict | None:
    """
    This function gets the pooling and normalize
    config from the model - only applies to
    sentence-transformers models.

    Args:
        model: The name of the Hugging Face model.
        revision: The specific version of the model to use.
            Defaults to 'main'.

    Returns:
        A dictionary ``{"pooling_type": ..., "normalize": ...}``, or None if
            no modules.json or no Pooling module is found. ``pooling_type``
            is None when it cannot be determined from the Pooling module's
            config file.
    """
    modules_file_name = "modules.json"

    modules_dict = None
    if file_or_path_exists(
        model=model, config_name=modules_file_name, revision=revision
    ):
        modules_dict = get_hf_file_to_dict(modules_file_name, model, revision)

    if modules_dict is None:
        return None

    logger.info("Found sentence-transformers modules configuration.")

    # Accept both a bare list of modules and a {"modules": [...]} wrapper,
    # matching the handling in try_get_dense_modules.
    if isinstance(modules_dict, dict):
        modules_dict = modules_dict.get("modules", [])

    pooling = next(
        (
            item
            for item in modules_dict
            if item.get("type") == "sentence_transformers.models.Pooling"
        ),
        None,
    )
    normalize = any(
        item.get("type") == "sentence_transformers.models.Normalize"
        for item in modules_dict
    )

    if pooling:
        pooling_file_name = f"{pooling['path']}/config.json"
        pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision)
        if pooling_dict is None:
            # A missing pooling config would otherwise fail on `.items()`;
            # keep the module information and leave the type unset.
            logger.warning(
                "Could not load pooling config '%s'; pooling type left unset.",
                pooling_file_name,
            )
            pooling_dict = {}

        # The active pooling mode is the first key whose value is exactly True.
        pooling_type_name = next(
            (item for item, val in pooling_dict.items() if val is True), None
        )

        if pooling_type_name is not None:
            pooling_type_name = get_pooling_config_name(pooling_type_name)

        logger.info("Found pooling configuration.")
        return {"pooling_type": pooling_type_name, "normalize": normalize}

    return None


941
def get_pooling_config_name(pooling_name: str) -> str | None:
942
943
944
945
946
947
948
949
950
    if "pooling_mode_" in pooling_name:
        pooling_name = pooling_name.replace("pooling_mode_", "")

    if "_" in pooling_name:
        pooling_name = pooling_name.split("_")[0]

    if "lasttoken" in pooling_name:
        pooling_name = "last"

951
    supported_pooling_types = ["LAST", "ALL", "CLS", "STEP", "MEAN"]
952
953
    pooling_type_name = pooling_name.upper()

954
955
956
    if pooling_type_name in supported_pooling_types:
        return pooling_type_name

957
    raise NotImplementedError(f"Pooling type {pooling_type_name} not supported")
958
959


@cache
def get_sentence_transformer_tokenizer_config(
    model: str | Path, revision: str | None = "main"
):
    """
    Returns the tokenization configuration dictionary for a
    given Sentence Transformer BERT model.

    Parameters:
    - model (str|Path): The name of the Sentence Transformer
    BERT model.
    - revision (str, optional): The revision of the model to use.
    Defaults to 'main'.

    Returns:
    - dict: A dictionary containing the configuration parameters
    for the Sentence Transformer BERT model, or None if no known config
    file is found or it lacks ``max_seq_length`` / ``do_lower_case``.
    """
    sentence_transformer_config_files = [
        "sentence_bert_config.json",
        "sentence_roberta_config.json",
        "sentence_distilbert_config.json",
        "sentence_camembert_config.json",
        "sentence_albert_config.json",
        "sentence_xlm-roberta_config.json",
        "sentence_xlnet_config.json",
    ]
    encoder_dict = None

    # First pass: files already available locally (model dir or HF cache).
    for config_file in sentence_transformer_config_files:
        if (
            try_get_local_file(model=model, file_name=config_file, revision=revision)
            is not None
        ):
            encoder_dict = get_hf_file_to_dict(config_file, model, revision)
            if encoder_dict:
                break

    # Second pass: list the Hub repo; skipped for absolute (local) paths.
    if not encoder_dict and not Path(model).is_absolute():
        try:
            # If model is on HuggingfaceHub, get the repo files
            repo_files = list_repo_files(
                model, revision=revision, token=_get_hf_token()
            )
        except Exception:
            # Best-effort: treat any listing failure as "no files".
            logger.debug(
                "Could not list repo files for %s", model, exc_info=True
            )
            repo_files = []

        for config_name in sentence_transformer_config_files:
            if config_name in repo_files:
                encoder_dict = get_hf_file_to_dict(config_name, model, revision)
                if encoder_dict:
                    break

    if not encoder_dict:
        return None

    logger.info("Found sentence-transformers tokenize configuration.")

    if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")):
        return encoder_dict
    return None


def maybe_register_config_serialize_by_value() -> None:
    """Try to register HF model configuration class to serialize by value

    If trust_remote_code is set, and the model's config file specifies an
    `AutoConfig` class, then the config class is typically an instance of
    a custom class imported from the HF modules cache.

    Examples:

    >>> from transformers import AutoConfig
    >>> klass = AutoConfig.from_pretrained(
    ...     "meta-llama/Meta-Llama-3-8B", trust_remote_code=True
    ... )
    >>> klass.__class__  # transformers.models.llama.configuration_llama.LlamaConfig
    >>> import transformers_modules  # error, not initialized
    >>> klass = AutoConfig.from_pretrained(
    ...     "deepseek-ai/DeepSeek-V2.5", trust_remote_code=True
    ... )
    >>> import transformers_modules  # success, initialized
    >>> klass.__class__  # transformers_modules.deepseek-ai.DeepSeek-V2.5.98b11844770b2c3ffc18b175c758a803640f4e77.configuration_deepseek.DeepseekV2Config

    In the DeepSeek example, the config class is an instance of a custom
    class that is not serializable by default. This class will not be
    importable in spawned workers, and won't exist at all on
    other nodes, which breaks serialization of the config.

    In this function we tell the cloudpickle serialization library to pass
    instances of these generated classes by value instead of by reference,
    i.e. the class definition is serialized along with its data so that the
    class module does not need to be importable on the receiving end.

    See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs
    """  # noqa
    # `transformers_modules` only exists once remote code has been loaded.
    try:
        import transformers_modules

        transformers_modules_available = True
    except ImportError:
        transformers_modules_available = False

    try:
        import multiprocessing
        import pickle

        import cloudpickle

        from vllm.config import VllmConfig

        # Register multiprocessing reducers to handle cross-process
        # serialization of VllmConfig objects that may contain custom configs
        # from transformers_modules
        def _reduce_config(config: VllmConfig):
            return (pickle.loads, (cloudpickle.dumps(config),))

        multiprocessing.reducer.register(VllmConfig, _reduce_config)

        # Register transformers_modules with cloudpickle if available
        if transformers_modules_available:
            cloudpickle.register_pickle_by_value(transformers_modules)

            # ray vendors its own version of cloudpickle
            from vllm.v1.executor.ray_utils import ray

            if ray:
                ray.cloudpickle.register_pickle_by_value(transformers_modules)

    except Exception as e:
        # Best-effort: failure here only matters if the config later has to
        # cross a process/node boundary.
        logger.warning(
            "Unable to register remote classes used by"
            " trust_remote_code with by-value serialization. This may"
            " lead to a later error. If remote code is not needed"
            " remove `--trust-remote-code`",
            exc_info=e,
        )


def get_hf_image_processor_config(
    model: str | Path,
    hf_token: bool | str | None = None,
    revision: str | None = None,
    **kwargs,
) -> dict[str, Any]:
    """Return the image processor config for ``model`` as a dict.

    Args:
        model: Model directory, GGUF file path, or Hub repo id. For a GGUF
            file the containing directory is used instead.
        hf_token: Forwarded as ``token`` to ``get_image_processor_config``.
        revision: Model revision to load.
        **kwargs: Forwarded to ``get_image_processor_config``.

    Returns:
        The image processor config; an empty dict when ModelScope is used.
    """
    # ModelScope does not provide an interface for image_processor
    if envs.VLLM_USE_MODELSCOPE:
        return {}
    # Separate model folder from file path for GGUF models
    if check_gguf_file(model):
        model = Path(model).parent
    return get_image_processor_config(
        model, token=hf_token, revision=revision, **kwargs
    )


def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to llm for multi modal models.

    No op for pure text models: ``config.get_text_config()`` is returned
    as-is, which is ``config`` itself when there is no separate text config.
    """
    text_config = config.get_text_config()

    if text_config is not config:
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if transformers config doesn't align with this assumption.
        assert hasattr(text_config, "num_attention_heads"), (
            "Text sub-config is expected to define `num_attention_heads`."
        )

    return text_config


def try_get_generation_config(
    model: str,
    trust_remote_code: bool,
    revision: str | None = None,
    config_format: str | ConfigFormat = "auto",
) -> GenerationConfig | None:
    """Load the generation config for ``model`` if one can be obtained.

    ``GenerationConfig.from_pretrained`` is tried first. If it raises
    ``OSError`` (not found), a generation config is derived from the model
    config returned by ``get_config``.

    Returns:
        The generation config, or None if ``get_config`` also raises
        ``OSError``.
    """
    try:
        return GenerationConfig.from_pretrained(
            model,
            revision=revision,
        )
    except OSError:  # Not found
        try:
            config = get_config(
                model,
                trust_remote_code=trust_remote_code,
                revision=revision,
                config_format=config_format,
            )
            return GenerationConfig.from_model_config(config)
        except OSError:  # Not found
            return None


def try_get_safetensors_metadata(
    model: str,
    *,
    revision: str | None = None,
):
    """Fetch the safetensors metadata of a Hub repository, or None.

    The ``get_safetensors_metadata`` call is wrapped in ``with_retry``; any
    exception that escapes it is swallowed and reported as None.
    """
    get_safetensors_metadata_partial = partial(
        get_safetensors_metadata,
        model,
        revision=revision,
        token=_get_hf_token(),
    )

    try:
        return with_retry(
            get_safetensors_metadata_partial, "Error retrieving safetensors"
        )
    except Exception:
        return None


def try_get_tokenizer_config(
    pretrained_model_name_or_path: str | os.PathLike,
    trust_remote_code: bool,
    revision: str | None = None,
) -> dict[str, Any] | None:
    """Return the tokenizer config via ``get_tokenizer_config``, or None.

    Any exception raised while loading is swallowed and reported as None.
    """
    try:
        return get_tokenizer_config(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
        )
    except Exception:
        return None


@cache
def try_get_dense_modules(
    model: str | Path,
    revision: str | None = None,
) -> list[dict[str, Any]] | None:
    """Load the configs of sentence-transformers Dense modules, if any.

    Reads ``modules.json``. For every entry whose ``type`` is
    ``sentence_transformers.models.Dense``, it loads ``<path>/config.json``,
    or the root ``config.json`` when the entry has no path. It records the
    module folder under the ``"folder"`` key of each loaded config.

    Returns:
        The list of loaded layer configs, skipping entries whose config
        could not be loaded. None if there is no modules.json, no Dense
        module, or an error occurs.
    """
    try:
        modules = get_hf_file_to_dict("modules.json", model, revision)
        if not modules:
            return None

        # modules.json may be a bare list or a {"modules": [...]} wrapper.
        if isinstance(modules, dict):
            modules = modules.get("modules", [])

        dense_modules = [
            m for m in modules if m.get("type") == "sentence_transformers.models.Dense"
        ]
        if not dense_modules:
            return None

        layer_configs = []
        for module in dense_modules:
            folder = module.get("path", "")

            config_path = f"{folder}/config.json" if folder else "config.json"
            layer_config = get_hf_file_to_dict(config_path, model, revision)
            if not layer_config:
                continue
            layer_config["folder"] = folder
            layer_configs.append(layer_config)
        return layer_configs
    except Exception:
        # Best-effort lookup, but leave a trace instead of failing silently.
        logger.debug(
            "Failed to load sentence-transformers Dense modules for %s",
            model,
            exc_info=True,
        )
        return None


1224
1225
1226
def get_safetensors_params_metadata(
    model: str,
    *,
1227
    revision: str | None = None,
1228
1229
1230
1231
1232
1233
1234
1235
1236
) -> dict[str, Any]:
    """
    Get the safetensors metadata for remote model repository.
    """
    full_metadata = {}
    if (model_path := Path(model)).exists():
        safetensors_to_check = model_path.glob("*.safetensors")
        full_metadata = {
            param_name: info
1237
1238
1239
            for file_path in safetensors_to_check
            if file_path.is_file()
            for param_name, info in parse_safetensors_file_metadata(file_path).items()
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
        }
    else:
        repo_mt = try_get_safetensors_metadata(model, revision=revision)
        if repo_mt and (files_mt := repo_mt.files_metadata):
            full_metadata = {
                param_name: asdict(info)
                for file_mt in files_mt.values()
                for param_name, info in file_mt.tensors.items()
            }
    return full_metadata


def _download_mistral_config_file(model, revision) -> dict:
    """Load the Mistral-format config file (``MISTRAL_CONFIG_NAME``) for a model.

    Raises:
        ValueError: If the file cannot be found locally or downloaded.
    """
    # Reuse the module-level constant so the name is defined in one place.
    config_file_name = MISTRAL_CONFIG_NAME
    config_dict = get_hf_file_to_dict(config_file_name, model, revision)
    if config_dict is None:
        raise ValueError(
            f"Failed to load mistral '{config_file_name}' config for model "
            f"{model}. Please check if the model is a mistral-format model "
            f"and if the config file exists."
        )
    assert isinstance(config_dict, dict)
    return config_dict


def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
    """Read ``max_position_embeddings`` from the model's HF-format config.

    Only ``trust_remote_code`` is read from ``kwargs``. If the HF config
    cannot be loaded, or the value is missing or falsy, 128_000 is returned
    and a warning is logged on error.
    """
    max_position_embeddings = 128_000
    try:
        trust_remote_code_val = kwargs.get("trust_remote_code", False)
        hf_config = get_config(
            model=model,
            trust_remote_code=trust_remote_code_val,
            revision=revision,
            config_format="hf",
        )
        if hf_value := hf_config.get_text_config().max_position_embeddings:
            max_position_embeddings = hf_value
    except Exception as e:
        logger.warning(
            "The params.json file is missing 'max_position_embeddings'"
            " and could not get a value from the HF config."
            " Defaulting to 128000",
            exc_info=e,
        )

    return max_position_embeddings


1288
def get_model_path(model: str | Path, revision: str | None = None):
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
    if os.path.exists(model):
        return model
    assert huggingface_hub.constants.HF_HUB_OFFLINE
    common_kwargs = {
        "local_files_only": huggingface_hub.constants.HF_HUB_OFFLINE,
        "revision": revision,
    }

    if envs.VLLM_USE_MODELSCOPE:
        from modelscope.hub.snapshot_download import snapshot_download
1299

1300
1301
1302
        return snapshot_download(model_id=model, **common_kwargs)

    from huggingface_hub import snapshot_download
1303

1304
    return snapshot_download(repo_id=model, **common_kwargs)
1305
1306


def get_hf_file_bytes(
    file_name: str, model: str | Path, revision: str | None = "main"
) -> bytes | None:
    """Get file contents from HuggingFace repository as bytes.

    The file is looked up locally first (model directory, then HF cache) and
    downloaded with ``hf_hub_download`` otherwise. Download errors are not
    caught here and propagate to the caller.

    Returns:
        The raw file contents, or None if the resolved path is not a file.
    """
    file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)

    if file_path is None:
        hf_hub_file = hf_hub_download(
            model, file_name, revision=revision, token=_get_hf_token()
        )
        file_path = Path(hf_hub_file)

    # file_path is always set here; it may still not be a regular file.
    if file_path.is_file():
        return file_path.read_bytes()

    return None