# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json
import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import List, Optional, Set, Union

import torch
from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_generation_config,
    get_hf_text_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, is_hip

logger = logging.getLogger(__name__)


class AttentionArch(IntEnum):
    MLA = auto()
    MHA = auto()


class ModelImpl(str, Enum):
    AUTO = "auto"
    SGLANG = "sglang"
    TRANSFORMERS = "transformers"


class ModelConfig:
    def __init__(
        self,
        model_path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: Optional[str] = None,
        is_embedding: Optional[bool] = None,
        enable_multimodal: Optional[bool] = None,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        override_config_file: Optional[str] = None,
        is_draft_model: bool = False,
        hybrid_kvcache_ratio: Optional[float] = None,
        impl: Union[str, ModelImpl] = ModelImpl.AUTO,
    ) -> None:

        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization
        self.impl = impl

        # Parse args
        self.maybe_pull_model_tokenizer_from_remote()
        self.model_override_args = json.loads(model_override_args or "{}")
        kwargs = {}
        if override_config_file and override_config_file.strip():
            kwargs["_configuration_file"] = override_config_file.strip()

        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
            **kwargs,
        )

        self.hf_generation_config = get_generation_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )

        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.attention_chunk_size = getattr(
            self.hf_text_config, "attention_chunk_size", None
        )
        self.is_hybrid = is_hybrid_model(
            self.hf_config.architectures,
            hybrid_kvcache_ratio=hybrid_kvcache_ratio,
            context_length=context_length,
            attention_chunk_size=self.attention_chunk_size,
        )
        if self.is_hybrid is not None:
            self.swa_attention_layer_ids, self.full_attention_layer_ids = (
                get_hybrid_layer_ids(
                    self.hf_config.architectures, self.hf_text_config.num_hidden_layers
                )
            )

        if enable_multimodal is None:
            mm_disabled_models = [
                "Gemma3ForConditionalGeneration",
                "Llama4ForConditionalGeneration",
            ]
            if self.hf_config.architectures[0] in mm_disabled_models:
                enable_multimodal = False
                logger.info(
                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
                )
            else:
                enable_multimodal = True

        if (
            is_draft_model
            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
        ):
            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"

        if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
            self.hf_config.architectures[0] = "MiMoMTP"
        # Check model type
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
        )
        self.is_multimodal = enable_multimodal and is_multimodal_model(
            self.hf_config.architectures
        )
        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
            self.hf_config.architectures
        )
        self.is_image_gen = enable_multimodal and is_image_gen_model(
            self.hf_config.architectures
        )
        self.is_audio_model = enable_multimodal and is_audio_model(
            self.hf_config.architectures
        )
        self.is_multimodal_chunked_prefill_supported = (
            enable_multimodal
            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
        )
        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

        # Derive context length
        derived_context_len = get_context_length(self.hf_text_config)
        if context_length is not None:
            if context_length > derived_context_len:
                if get_bool_env_var(
                    "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True"
                ):
                    logger.warning(
                        f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                        f"This may lead to incorrect model outputs or CUDA errors."
                    )
                    self.context_len = context_length
                else:
                    raise ValueError(
                        f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                        f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
                        f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                    )
            else:
                self.context_len = context_length
        else:
            self.context_len = derived_context_len

        # Unify the config keys for hf_text_config
        self.head_dim = getattr(
            self.hf_text_config,
            "head_dim",
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
        )

        # FIXME: temporary special-case handling for MLA architectures
        if (
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim

            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            if self.hf_config.rope_scaling:
                mscale_all_dim = self.hf_config.rope_scaling.get(
                    "mscale_all_dim", False
                )
                scaling_factor = self.hf_config.rope_scaling["factor"]
                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
                self.scaling = self.scaling * mscale * mscale

        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr(
            self.hf_text_config, "use_mla", True
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
        elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
            self.v_head_dim = self.hf_text_config.v_head_dim
            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
        else:
            if (
                "MistralModel" in self.hf_config.architectures
                or "MixtralForCausalLM" in self.hf_config.architectures
                or "MistralForCausalLM" in self.hf_config.architectures
            ):
                if getattr(self, "head_dim", None) is None:
                    self.head_dim = (
                        self.hf_config.hidden_size // self.hf_config.num_attention_heads
                    )
                    # In transformers==4.52.3, head_dim is None in MistralConfig
                    if (
                        not hasattr(self.hf_text_config, "head_dim")
                        or self.hf_text_config.head_dim is None
                    ):
                        setattr(self.hf_text_config, "head_dim", self.head_dim)

            self.attention_arch = AttentionArch.MHA

        self.num_attention_heads = self.hf_text_config.num_attention_heads
        self.num_key_value_heads = getattr(
            self.hf_text_config, "num_key_value_heads", None
        )

        # for Dbrx and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
        self.vocab_size = self.hf_text_config.vocab_size

        # Verify quantization
        self._verify_quantization()

        # Cache attributes
        self.hf_eos_token_id = self.get_hf_eos_token_id()

        config = self.hf_config

        # multimodal
        self.image_token_id = getattr(config, "image_token_id", None) or getattr(
            config, "image_token_index", None
        )

    @staticmethod
    def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
        return ModelConfig(
            model_path=model_path or server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
            impl=server_args.impl,
            **kwargs,
        )
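    # Illustrative usage sketch (not part of the original module): building a
    # ModelConfig from already-parsed server arguments. The ServerArgs fields
    # shown are assumed values for a typical invocation, not exhaustive.
    #
    #     server_args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")
    #     model_config = ModelConfig.from_server_args(server_args)
    #     print(model_config.context_len, model_config.dtype, model_config.attention_arch)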

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, total_num_kv_heads // tensor_parallel_size)
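    # Illustrative example: a model with 8 total KV heads served with
    # tensor_parallel_size=16 yields max(1, 8 // 16) == 1, i.e. each GPU keeps
    # one replicated KV head rather than a fractional share; with
    # tensor_parallel_size=4 it yields 8 // 4 == 2 KV heads per GPU.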

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _parse_quant_hf_config(self):
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        if quant_cfg is None:
            # Check whether this is a ModelOpt model -- ModelOpt has no corresponding field
            # in the HF `config.json`, but ships a standalone `hf_quant_config.json` in the root directory.
            # Example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
            is_local = os.path.exists(self.model_path)
            modelopt_quant_config = {"quant_method": "modelopt"}
            if not is_local:
                from huggingface_hub import HfApi

                hf_api = HfApi()
                if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
                    quant_cfg = modelopt_quant_config
            elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
                quant_config_file = os.path.join(
                    self.model_path, "hf_quant_config.json"
                )
                with open(quant_config_file) as f:
                    quant_config_dict = json.load(f)
                json_quant_configs = quant_config_dict["quantization"]
                quant_algo = json_quant_configs.get("quant_algo", None)
                if quant_algo == "MIXED_PRECISION":
                    quant_cfg = {"quant_method": "w4afp8"}
                else:
                    quant_cfg = modelopt_quant_config
        return quant_cfg
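    # Illustrative outcomes of the parsing above: a checkpoint whose config carries
    # `quantization_config` returns that dict unchanged; a ModelOpt checkpoint that
    # only ships `hf_quant_config.json` returns {"quant_method": "modelopt"}, or
    # {"quant_method": "w4afp8"} when that file declares quant_algo == "MIXED_PRECISION".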

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _verify_quantization(self) -> None:
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq",
            "gptq",
            "fp8",
            "compressed_tensors",
            "compressed-tensors",
            "fbgemm_fp8",
            "w8a8_fp8",
            "petit_nvfp4",
        ]
        optimized_quantization_methods = [
            "fp8",
            "marlin",
            "modelopt",
            "gptq_marlin_24",
            "gptq_marlin",
            "awq_marlin",
            "fbgemm_fp8",
            "compressed_tensors",
            "compressed-tensors",
            "experts_int8",
            "w8a8_int8",
            "w8a8_fp8",
            "moe_wna16",
            "qoq",
            "w4afp8",
            "petit_nvfp4",
        ]
        compatible_quantization_methods = {
            "modelopt_fp4": ["modelopt"],
            "petit_nvfp4": ["modelopt"],
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
        }
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get(
                "quant_method", "" if not self.quantization else self.quantization
            ).lower()

            # Detect which checkpoint format it is
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization
                )
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                if (
                    self.quantization not in compatible_quantization_methods
                    or quant_method
                    not in compatible_quantization_methods[self.quantization]
                ):
                    raise ValueError(
                        "Quantization method specified in the model config "
                        f"({quant_method}) does not match the quantization "
                        f"method specified in the `quantization` argument "
                        f"({self.quantization})."
                    )

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}."
                )
            if is_hip() and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm."
                )
            if self.quantization not in optimized_quantization_methods:
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.",
                    self.quantization,
                )

    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
        if eos_ids is None:
            eos_ids = set()
        if self.hf_generation_config:
            generation_eos_ids = getattr(
                self.hf_generation_config, "eos_token_id", None
            )
            if generation_eos_ids:
                generation_eos_ids = (
                    {generation_eos_ids}
                    if isinstance(generation_eos_ids, int)
                    else set(generation_eos_ids)
                )
                eos_ids = eos_ids | generation_eos_ids
        return eos_ids
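    # Illustrative example: with config.json declaring eos_token_id = 2 and
    # generation_config.json declaring eos_token_id = [2, 32007] (hypothetical
    # values), this method returns the union {2, 32007}.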

    def maybe_pull_model_tokenizer_from_remote(self) -> None:
        """
        Pull the model config files to a temporary
        directory in case of remote.

        Args:
            model: The model name or path.

        """
        from sglang.srt.connector import create_remote_connector
        from sglang.srt.utils import is_remote_url

        if is_remote_url(self.model_path):
            logger.info("Pulling model configs from remote...")
            # BaseConnector implements __del__() to clean up the local dir.
            # Since the config files must remain available for the whole process
            # lifetime, we deliberately do NOT use a `with` statement, which
            # would close the client and delete the local directory.
            client = create_remote_connector(self.model_path)
            if is_remote_url(self.model_path):
                client.pull_files(allow_pattern=["*config.json"])
                self.model_weights = self.model_path
                self.model_path = client.get_local_dir()


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if isinstance(config_dtype, str):
        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type.startswith("gemma"):
                    if config.model_type == "gemma":
                        gemma_version = ""
                    else:
                        gemma_version = config.model_type[5]
                    logger.info(
                        f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16."
                    )
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
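# Illustrative behavior of _get_and_verify_dtype: with dtype="auto" and a config
# whose torch_dtype is float32, non-Gemma models are downcast to torch.float16
# while Gemma-family models are downcast to torch.bfloat16; an explicit
# dtype="bfloat16" always maps to torch.bfloat16.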


def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    # We have two ways to determine whether a model is a generative model.
    # 1. Check the model architecture
    # 2. Check the `is_embedding` server argument

    if (
        "LlamaEmbeddingModel" in model_architectures
        or "MistralModel" in model_architectures
        or "LlamaForSequenceClassification" in model_architectures
        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
        or "InternLM2ForRewardModel" in model_architectures
        or "Qwen2ForRewardModel" in model_architectures
        or "Qwen2ForSequenceClassification" in model_architectures
        or "CLIPModel" in model_architectures
        or "BertModel" in model_architectures
        or "Contriever" in model_architectures
        or "BertForSequenceClassification" in model_architectures
        or "XLMRobertaModel" in model_architectures
        or "XLMRobertaForSequenceClassification" in model_architectures
    ):
        return False
    else:
        return not is_embedding


multimodal_model_archs = [
    "CLIPModel",
    "DeepseekVL2ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3nForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "LlavaLlamaForCausalLM",
    "Llama4ForConditionalGeneration",
    "LlavaMistralForCausalLM",
    "LlavaQwenForCausalLM",
    "LlavaForConditionalGeneration",
    "LlavaVidForCausalLM",
    "MiniCPMO",
    "MiniCPMV",
    "Mistral3ForConditionalGeneration",
    "MultiModalityCausalLM",
    "MllamaForConditionalGeneration",
    "Qwen2AudioForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "KimiVLForConditionalGeneration",
    "InternVLChatModel",
    "Phi4MMForCausalLM",
    "VILAForConditionalGeneration",
]


def is_multimodal_model(model_architectures: List[str]):
    if any(
        multi_model_arch in model_architectures
        for multi_model_arch in multimodal_model_archs
    ):
        return True
    else:
        return False


def is_multimodal_gen_model(model_architectures: List[str]):
    return False


def is_image_gen_model(model_architectures: List[str]):
    return False


def is_audio_model(model_architectures: List[str]):
    return False


def is_encoder_decoder_model(model_architectures: List[str]):
    return "MllamaForConditionalGeneration" in model_architectures


def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
    """Check if chunked prefill is supported for a MultiModal model."""
    unsupported = [
        "Grok1VForCausalLM",
        "Grok1AForCausalLM",
        "LlavaLlamaForCausalLM",
        "MllamaForConditionalGeneration",
        "CLIPModel",
    ]
    if any(multi_model_arch in unsupported for multi_model_arch in model_architectures):
        return False
    else:
        return True


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
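# Worked example (illustrative): for a YaRN rope-scaling factor of 40 with
# mscale_all_dim = 1.0, yarn_get_mscale(40, 1.0) = 0.1 * 1.0 * ln(40) + 1 ≈ 1.369;
# ModelConfig then multiplies the base 1/sqrt(qk_nope_head_dim + qk_rope_head_dim)
# scaling by mscale**2 ≈ 1.874.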


def is_hybrid_model(
    model_architectures: List[str],
    hybrid_kvcache_ratio: Optional[float],
    context_length: Optional[int],
    attention_chunk_size: Optional[int],
):
    if hybrid_kvcache_ratio is None:
        return None
    elif (
        hybrid_kvcache_ratio > 0
        and model_architectures[0] == "Llama4ForConditionalGeneration"
        and context_length > attention_chunk_size
    ):
        return hybrid_kvcache_ratio
    else:
        return None
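# Illustrative example: hybrid_kvcache_ratio=0.5 for a Llama4ForConditionalGeneration
# checkpoint with context_length=65536 and attention_chunk_size=8192 (hypothetical
# values) returns 0.5; any other architecture, or a context no longer than the
# attention chunk, returns None.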


def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
    if "Llama4ForConditionalGeneration" in model_architectures:
        swa_attention_layer_ids = [
            i for i in range(num_hidden_layers) if (i + 1) % 4 != 0
        ]
        full_attention_layer_ids = [
            i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
        ]
    else:
        swa_attention_layer_ids = None
        full_attention_layer_ids = None
    return swa_attention_layer_ids, full_attention_layer_ids
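# Illustrative example: for Llama4ForConditionalGeneration with num_hidden_layers=8,
# every fourth layer uses full attention, so the split is
# swa_attention_layer_ids=[0, 1, 2, 4, 5, 6] and full_attention_layer_ids=[3, 7].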