# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json
import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import List, Optional, Set, Union

import torch
from transformers import PretrainedConfig

from sglang.srt.environ import envs
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_hip, retry
from sglang.srt.utils.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_generation_config,
    get_hf_text_config,
    get_sparse_attention_config,
)
from sglang.utils import is_in_ci

logger = logging.getLogger(__name__)


class AttentionArch(IntEnum):
    MLA = auto()
    MHA = auto()


class ModelImpl(str, Enum):
    AUTO = "auto"
    SGLANG = "sglang"
    TRANSFORMERS = "transformers"


class ModelConfig:
    def __init__(
        self,
        model_path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: str = "{}",
        is_embedding: Optional[bool] = None,
        enable_multimodal: Optional[bool] = None,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        override_config_file: Optional[str] = None,
        is_draft_model: bool = False,
        hybrid_kvcache_ratio: Optional[float] = None,
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
    ) -> None:
        # Parse args
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization
        self.is_draft_model = is_draft_model
        self.model_impl = model_impl

        # Get hf config
        self._maybe_pull_model_tokenizer_from_remote()
        self.model_override_args = json.loads(model_override_args)
        kwargs = {}
        if override_config_file and override_config_file.strip():
            kwargs["_configuration_file"] = override_config_file.strip()
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
            **kwargs,
        )
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.hf_generation_config = get_generation_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )

        # Set enable_multimodal
        if enable_multimodal is None:
            mm_disabled_models = [
                "Gemma3ForConditionalGeneration",
                "Llama4ForConditionalGeneration",
                "Step3VLForConditionalGeneration",
            ]
            if self.hf_config.architectures[0] in mm_disabled_models:
                enable_multimodal = False
                logger.info(
                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
                )
            else:
                enable_multimodal = True

        # Config draft model
        self._config_draft_model()

        # Check model type
        self.attention_chunk_size = getattr(
            self.hf_text_config, "attention_chunk_size", None
        )
        self.is_hybrid = is_hybrid_model(
            self.hf_config.architectures,
            hybrid_kvcache_ratio=hybrid_kvcache_ratio,
            context_length=context_length,
            attention_chunk_size=self.attention_chunk_size,
        )
        if self.is_hybrid is not None:
            self.swa_attention_layer_ids, self.full_attention_layer_ids = (
                get_hybrid_layer_ids(
                    self.hf_config.architectures, self.hf_text_config.num_hidden_layers
                )
            )
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
        )
        self.is_multimodal = enable_multimodal and is_multimodal_model(
            self.hf_config.architectures
        )
        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
            self.hf_config.architectures
        )
        self.is_image_gen = enable_multimodal and is_image_gen_model(
            self.hf_config.architectures
        )
        self.is_audio_model = enable_multimodal and is_audio_model(
            self.hf_config.architectures
        )
        self.is_multimodal_chunked_prefill_supported = (
            enable_multimodal
            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
        )
        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

        # Derive context length and model shapes
        self._derive_context_length(context_length)
        self._derive_model_shapes()

        # Verify quantization
        self._verify_quantization()

        # Verify dual-chunk attention config
        self._verify_dual_chunk_attention_config()

        # Cache attributes
        self.hf_eos_token_id = self._get_hf_eos_token_id()

        # multimodal
        self.image_token_id = getattr(
            self.hf_config, "image_token_id", None
        ) or getattr(self.hf_config, "image_token_index", None)

    @staticmethod
    def from_server_args(
        server_args: ServerArgs,
        model_path: str = None,
        model_revision: str = None,
        **kwargs,
    ):
        return ModelConfig(
            model_path=model_path or server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=model_revision or server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
            model_impl=server_args.model_impl,
            **kwargs,
        )
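
    # Illustrative direct construction (a minimal sketch; the model path and the
    # argument values below are assumptions for the example, not defaults shipped
    # by this module):
    #
    #   config = ModelConfig(
    #       model_path="meta-llama/Llama-3.1-8B-Instruct",
    #       context_length=8192,
    #       dtype="bfloat16",
    #   )
    #   config.context_len, config.dtype, config.get_num_kv_heads(tensor_parallel_size=8)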

    def _config_draft_model(self):
        is_draft_model = self.is_draft_model

        if (
            is_draft_model
            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
        ):
            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"

        if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
            self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"

        if (
            is_draft_model
            and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
        ):
            self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
            self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers

        if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
            self.hf_config.architectures[0] = "MiMoMTP"
        if is_draft_model and self.hf_config.architectures[0] in [
            "BailingMoeV2ForCausalLM",
            "BailingMoeForCausalLM",
        ]:
            self.hf_config.architectures[0] = "BailingMoeForCausalLMNextN"
        if (
            is_draft_model
            and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
        ):
            self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"

        if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM":
            self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
            self.hf_config.num_nextn_predict_layers = 1

    def _derive_context_length(self, context_length: int):
        is_draft_model = self.is_draft_model
        derived_context_len = get_context_length(self.hf_text_config)

        if context_length is not None:
            if context_length > derived_context_len:
                reason = "Target model's" if is_draft_model else "User-specified"
                msg = (
                    f"Warning: {reason} context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                    f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
                )
                if (
                    envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.get()
                    or is_in_ci()  # FIXME: fix this special case
                ):
                    logger.warning(msg)
                    self.context_len = context_length
                    if is_draft_model:
                        self.hf_text_config.max_position_embeddings = context_length
                        logger.warning(
                            f"Overriding the draft model's max_position_embeddings to {context_length}."
                        )
                else:
                    raise ValueError(
                        f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                    )
            else:
                self.context_len = context_length
        else:
            self.context_len = derived_context_len

        # Transfer context_len to HuggingFace config so models can access it
        self.hf_config.context_len = self.context_len
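
    # Illustrative behavior of the override logic above (numbers are assumptions):
    # serving a checkpoint whose derived context length is 32768 with
    # context_length=131072 raises a ValueError unless the env var
    # SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 is set (or we are in CI), in which
    # case only a warning is logged and context_len is overridden.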

    def _derive_model_shapes(self):
        # Unify the config keys for hf_text_config
        self.head_dim = getattr(
            self.hf_text_config,
            "head_dim",
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
        )

        # FIXME: temporary special-case check for MLA architectures
        if (
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
            or "LongcatFlashForCausalLM" in self.hf_config.architectures
            or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
            or "DotsVLMForCausalLM" in self.hf_config.architectures
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim

            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            if self.hf_config.rope_scaling:
                mscale_all_dim = self.hf_config.rope_scaling.get(
                    "mscale_all_dim", False
                )
                scaling_factor = self.hf_config.rope_scaling["factor"]
                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
                self.scaling = self.scaling * mscale * mscale

        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr(
            self.hf_text_config, "use_mla", True
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
        elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
            self.v_head_dim = self.hf_text_config.v_head_dim
            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
        else:
            if (
                "MistralModel" in self.hf_config.architectures
                or "MixtralForCausalLM" in self.hf_config.architectures
                or "MistralForCausalLM" in self.hf_config.architectures
            ):
                if getattr(self, "head_dim", None) is None:
                    self.head_dim = (
                        self.hf_config.hidden_size // self.hf_config.num_attention_heads
                    )
                    # In transformers==4.52.3, the head_dim is null in MistralConfig
                    if (
                        not hasattr(self.hf_text_config, "head_dim")
                        or self.hf_text_config.head_dim is None
                    ):
                        setattr(self.hf_text_config, "head_dim", self.head_dim)

            self.attention_arch = AttentionArch.MHA

        self.num_attention_heads = self.hf_text_config.num_attention_heads
        self.num_key_value_heads = getattr(
            self.hf_text_config, "num_key_value_heads", None
        )

        # for Dbrx and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
        self.num_attention_layers = self.num_hidden_layers
        if "LongcatFlashForCausalLM" in self.hf_config.architectures:
            self.num_attention_layers = self.num_hidden_layers * 2
        self.num_nextn_predict_layers = getattr(
            self.hf_text_config, "num_nextn_predict_layers", None
        )
        self.vocab_size = self.hf_text_config.vocab_size

    def get_total_num_attention_heads(self) -> int:
        return self.num_attention_heads

    def get_num_attention_heads(self, tensor_parallel_size) -> int:
        total_num_attention_heads = self.num_attention_heads
        return max(1, total_num_attention_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )
        if self.hf_config.model_type in ["nemotron-nas"]:
            nkvh = {
                self.hf_config.num_attention_heads // block.attention.n_heads_in_group
                for block in self.hf_config.block_configs
                if not block.attention.no_op
            }
            if len(nkvh) == 0:
                raise RuntimeError("Couldn't determine number of kv heads")
            if len(nkvh) > 1:
                raise ValueError(
                    "Variable GQA (VGQA) is not yet supported for nemotron-nas in sglang"
                )
            return next(iter(nkvh))

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
            # For Step3
            "num_attention_groups",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, total_num_kv_heads // tensor_parallel_size)
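
    # Worked example for the division above (numbers are illustrative): a model with
    # 8 total KV heads on tensor_parallel_size=16 gives max(1, 8 // 16) == 1, i.e.
    # the KV heads are replicated so every GPU still holds one; 32 KV heads on
    # tensor_parallel_size=8 gives 4 KV heads per GPU.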

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _parse_quant_hf_config(self):
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        if quant_cfg is None:
            # Check whether this is a ModelOpt or mixed-precision model. Neither has
            # a corresponding field in the HF `config.json`, but both ship a
            # standalone `hf_quant_config.json` in the root directory.
            # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
            # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
            is_local = os.path.exists(self.model_path)
            modelopt_quant_config = {"quant_method": "modelopt"}
            if not is_local:
                import huggingface_hub

                try:
                    from huggingface_hub import HfApi

                    hf_api = HfApi()

                    def check_hf_quant_config():
                        return hf_api.file_exists(
                            self.model_path, "hf_quant_config.json"
                        )

                    # Retry HF API call up to 3 times
                    file_exists = retry(
                        check_hf_quant_config,
                        max_retry=2,
                        initial_delay=1.0,
                        max_delay=5.0,
                    )

                    if file_exists:
                        quant_cfg = modelopt_quant_config

                except huggingface_hub.errors.OfflineModeIsEnabled:
                    logger.warning(
                        "Offline mode is enabled, skipping hf_quant_config.json check"
                    )
                except Exception as e:
                    logger.warning(
                        f"Failed to check hf_quant_config.json: {self.model_path} {e}"
                    )

            elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
                quant_config_file = os.path.join(
                    self.model_path, "hf_quant_config.json"
                )
                with open(quant_config_file) as f:
                    quant_config_dict = json.load(f)
                json_quant_configs = quant_config_dict["quantization"]
                quant_algo = json_quant_configs.get("quant_algo", None)
                if quant_algo == "MIXED_PRECISION":
                    quant_cfg = {"quant_method": "w4afp8"}
                else:
                    quant_cfg = modelopt_quant_config
        return quant_cfg
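
    # A minimal sketch of the `hf_quant_config.json` shape this parser expects
    # (inferred from the reads above; the field values are illustrative):
    #
    #   {"quantization": {"quant_algo": "FP8"}}              -> treated as a modelopt checkpoint
    #   {"quantization": {"quant_algo": "MIXED_PRECISION"}}  -> mapped to the "w4afp8" method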

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _verify_quantization(self) -> None:
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq",
            "gptq",
            "fp8",
            "compressed_tensors",
            "compressed-tensors",
            "fbgemm_fp8",
            "w8a8_fp8",
            "petit_nvfp4",
            "quark",
            "mxfp4",
        ]
        optimized_quantization_methods = [
            "fp8",
            "marlin",
            "modelopt",
            "gptq_marlin_24",
            "gptq_marlin",
            "awq_marlin",
            "fbgemm_fp8",
            "compressed_tensors",
            "compressed-tensors",
            "experts_int8",
            "w8a8_int8",
            "w8a8_fp8",
            "moe_wna16",
            "qoq",
            "w4afp8",
            "petit_nvfp4",
        ]
        compatible_quantization_methods = {
            "modelopt_fp4": ["modelopt"],
            "petit_nvfp4": ["modelopt"],
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
        }
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get(
                "quant_method", "" if not self.quantization else self.quantization
            ).lower()

            # Detect which kind of checkpoint this is
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization
                )
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                if (
                    self.quantization not in compatible_quantization_methods
                    or quant_method
                    not in compatible_quantization_methods[self.quantization]
                ):
                    raise ValueError(
                        "Quantization method specified in the model config "
                        f"({quant_method}) does not match the quantization "
                        f"method specified in the `quantization` argument "
                        f"({self.quantization})."
                    )

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}."
                )
            if is_hip() and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm."
                )
            if self.quantization not in optimized_quantization_methods:
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.",
                    self.quantization,
                )

    def _verify_dual_chunk_attention_config(self) -> None:
        if hasattr(self.hf_config, "dual_chunk_attention_config"):
            # Try loading the sparse attention config
            sparse_attn_config = get_sparse_attention_config(self.model_path)
            if not sparse_attn_config:
                return
            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
                sparse_attn_config
            )
            if (
                "sparse_attention_enabled"
                not in self.hf_config.dual_chunk_attention_config
            ):
                self.hf_config.dual_chunk_attention_config[
                    "sparse_attention_enabled"
                ] = True

    def _get_hf_eos_token_id(self) -> Optional[Set[int]]:
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
        if eos_ids is None:
            eos_ids = set()
        if self.hf_generation_config:
            generation_eos_ids = getattr(
                self.hf_generation_config, "eos_token_id", None
            )
            if generation_eos_ids:
                generation_eos_ids = (
                    {generation_eos_ids}
                    if isinstance(generation_eos_ids, int)
                    else set(generation_eos_ids)
                )
                eos_ids = eos_ids | generation_eos_ids
        return eos_ids
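
    # Illustrative shapes handled above (token ids are example values): a config.json
    # with `"eos_token_id": 2` yields {2}; `"eos_token_id": [151645, 151643]` combined
    # with a generation_config.json eos_token_id of 151643 yields {151645, 151643}.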

    def _maybe_pull_model_tokenizer_from_remote(self) -> None:
        """
        Pull the model config files to a temporary
        directory in case of remote.

        Args:
            model: The model name or path.

        """
        from sglang.srt.connector import create_remote_connector
        from sglang.srt.utils import is_remote_url

        if is_remote_url(self.model_path):
            logger.info("Pulling model configs from remote...")
            # BaseConnector implements __del__() to clean up the local dir.
            # Since the config files need to exist for the whole lifetime of the
            # process, we deliberately do NOT use a `with` statement, which would
            # close the client and delete them.
            client = create_remote_connector(self.model_path)
            client.pull_files(allow_pattern=["*config.json"])
            self.model_weights = self.model_path
            self.model_path = client.get_local_dir()


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if isinstance(config_dtype, str):
        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type.startswith("gemma"):
                    if config.model_type == "gemma":
                        gemma_version = ""
                    else:
                        gemma_version = config.model_type[5]
                    logger.info(
                        f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16."
                    )
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
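
# Illustrative resolutions of _get_and_verify_dtype (example inputs): a config with
# torch_dtype=bfloat16 and dtype="auto" resolves to torch.bfloat16; torch_dtype=float32
# with dtype="auto" is downcast to torch.float16 (torch.bfloat16 for gemma* models);
# an explicit dtype="float32" upcasts with an info log, while float16 <-> bfloat16
# casts only emit a warning.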


def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    # We have two ways to determine whether a model is a generative model.
    # 1. Check the model architecture
    # 2. Check the `is_embedding` server argument

    if (
        "LlamaEmbeddingModel" in model_architectures
        or "MistralModel" in model_architectures
        or "LlamaForSequenceClassification" in model_architectures
        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
        or "InternLM2ForRewardModel" in model_architectures
        or "Qwen2ForRewardModel" in model_architectures
        or "Qwen2ForSequenceClassification" in model_architectures
        or "Qwen3ForSequenceClassification" in model_architectures
        or "CLIPModel" in model_architectures
        or "BertModel" in model_architectures
        or "Contriever" in model_architectures
        or "BertForSequenceClassification" in model_architectures
        or "XLMRobertaModel" in model_architectures
        or "XLMRobertaForSequenceClassification" in model_architectures
    ):
        return False
    else:
        return not is_embedding


multimodal_model_archs = [
    "CLIPModel",
    "DeepseekVL2ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3nForConditionalGeneration",
    "Glm4vForConditionalGeneration",
    "Glm4vMoeForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "LlavaLlamaForCausalLM",
    "Llama4ForConditionalGeneration",
    "LlavaMistralForCausalLM",
    "LlavaQwenForCausalLM",
    "LlavaForConditionalGeneration",
    "LlavaVidForCausalLM",
    "MiniCPMO",
    "MiniCPMV",
    "Mistral3ForConditionalGeneration",
    "MultiModalityCausalLM",
    "MllamaForConditionalGeneration",
    "Qwen2AudioForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "Qwen3VLForConditionalGeneration",
    "Qwen3VLMoeForConditionalGeneration",
    "KimiVLForConditionalGeneration",
    "InternVLChatModel",
    "InternS1ForConditionalGeneration",
    "Phi4MMForCausalLM",
    "VILAForConditionalGeneration",
    "Step3VLForConditionalGeneration",
    "DotsVLMForCausalLM",
    "DotsOCRForCausalLM",
    "Sarashina2VisionForCausalLM",
]


def is_multimodal_model(model_architectures: List[str]):
    if any(
        multi_model_arch in model_architectures
        for multi_model_arch in multimodal_model_archs
    ):
        return True
    else:
        return False


def is_multimodal_gen_model(model_architectures: List[str]):
    return False


def is_image_gen_model(model_architectures: List[str]):
    return False


def is_audio_model(model_architectures: List[str]):
    return False


def is_encoder_decoder_model(model_architectures: List[str]):
    return "MllamaForConditionalGeneration" in model_architectures


def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
    """Check if chunked prefill is supported for a MultiModal model."""
    unsupported = [
        "Grok1VForCausalLM",
        "Grok1AForCausalLM",
        "LlavaLlamaForCausalLM",
        "MllamaForConditionalGeneration",
        "CLIPModel",
    ]
    if any(multi_model_arch in unsupported for multi_model_arch in model_architectures):
        return False
    else:
        return True


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
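
# Worked example (numbers are illustrative): yarn_get_mscale(scale=40.0, mscale=1.0)
# returns 0.1 * 1.0 * ln(40) + 1.0 ≈ 1.369. ModelConfig._derive_model_shapes multiplies
# the base attention scaling 1 / sqrt(qk_nope_head_dim + qk_rope_head_dim) by
# mscale ** 2 when rope_scaling is configured; a scale <= 1 leaves the factor at 1.0.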


def is_hybrid_model(
    model_architectures: List[str],
    hybrid_kvcache_ratio: Optional[float],
    context_length: Optional[int],
    attention_chunk_size: Optional[int],
):
    if hybrid_kvcache_ratio is None:
        return None
    elif (
        hybrid_kvcache_ratio > 0
        and model_architectures[0] == "Llama4ForConditionalGeneration"
        and context_length > attention_chunk_size
    ):
        return hybrid_kvcache_ratio
    else:
        return None


def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
    if "Llama4ForConditionalGeneration" in model_architectures:
        swa_attention_layer_ids = [
            i for i in range(num_hidden_layers) if (i + 1) % 4 != 0
        ]
        full_attention_layer_ids = [
            i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
        ]
    else:
        swa_attention_layer_ids = None
        full_attention_layer_ids = None
    return swa_attention_layer_ids, full_attention_layer_ids
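
# Illustrative output of the Llama4 split above: with num_hidden_layers=8, layers 3
# and 7 (0-based; every 4th layer) get full attention and the remaining layers use
# sliding-window attention.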