# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
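"""Model configuration for SGLang.

Loads a HuggingFace model config and derives the runtime attributes used by the
server: context length, attention architecture and head shapes, multimodal
capabilities, dtype, and quantization settings.
"""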

import json
import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import List, Optional, Set, Union

import torch
from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_generation_config,
    get_hf_text_config,
    get_sparse_attention_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, is_hip, retry
from sglang.utils import is_in_ci

logger = logging.getLogger(__name__)


class AttentionArch(IntEnum):
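    """Attention architecture: MLA (multi-head latent attention) or standard MHA."""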
    MLA = auto()
    MHA = auto()


class ModelImpl(str, Enum):
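    """Which model implementation to use: auto-detect, the native SGLang model, or the HuggingFace Transformers backend."""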
    AUTO = "auto"
    SGLANG = "sglang"
    TRANSFORMERS = "transformers"


class ModelConfig:
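    """Model configuration derived from a HuggingFace config plus server-side overrides
    (context length, dtype, quantization, multimodal flags, attention shapes)."""
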
    def __init__(
        self,
        model_path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: str = "{}",
        is_embedding: Optional[bool] = None,
        enable_multimodal: Optional[bool] = None,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        override_config_file: Optional[str] = None,
        is_draft_model: bool = False,
        hybrid_kvcache_ratio: Optional[float] = None,
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
    ) -> None:
        # Parse args
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization
        self.is_draft_model = is_draft_model
        self.model_impl = model_impl

        # Get hf config
        self._maybe_pull_model_tokenizer_from_remote()
        self.model_override_args = json.loads(model_override_args)
        kwargs = {}
        if override_config_file and override_config_file.strip():
            kwargs["_configuration_file"] = override_config_file.strip()
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
            **kwargs,
        )
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.hf_generation_config = get_generation_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )

        # Set enable_multimodal
        if enable_multimodal is None:
            mm_disabled_models = [
                "Gemma3ForConditionalGeneration",
                "Llama4ForConditionalGeneration",
                "Step3VLForConditionalGeneration",
            ]
            if self.hf_config.architectures[0] in mm_disabled_models:
                enable_multimodal = False
                logger.info(
                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
                )
            else:
                enable_multimodal = True

        # Config draft model
        self._config_draft_model()

        # Check model type
        self.attention_chunk_size = getattr(
            self.hf_text_config, "attention_chunk_size", None
        )
        self.is_hybrid = is_hybrid_model(
            self.hf_config.architectures,
            hybrid_kvcache_ratio=hybrid_kvcache_ratio,
            context_length=context_length,
            attention_chunk_size=self.attention_chunk_size,
        )
        if self.is_hybrid is not None:
            self.swa_attention_layer_ids, self.full_attention_layer_ids = (
                get_hybrid_layer_ids(
                    self.hf_config.architectures, self.hf_text_config.num_hidden_layers
                )
            )
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
        )
        self.is_multimodal = enable_multimodal and is_multimodal_model(
            self.hf_config.architectures
        )
        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
            self.hf_config.architectures
        )
        self.is_image_gen = enable_multimodal and is_image_gen_model(
            self.hf_config.architectures
        )
        self.is_audio_model = enable_multimodal and is_audio_model(
            self.hf_config.architectures
        )
        self.is_multimodal_chunked_prefill_supported = (
            enable_multimodal
            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
        )
        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

        # Derive context length and model shapes
        self._derive_context_length(context_length)
        self._derive_model_shapes()

        # Verify quantization
        self._verify_quantization()

        # Verify dual-chunk attention config
        self._verify_dual_chunk_attention_config()

        # Cache attributes
        self.hf_eos_token_id = self._get_hf_eos_token_id()

        # multimodal
        self.image_token_id = getattr(
            self.hf_config, "image_token_id", None
        ) or getattr(self.hf_config, "image_token_index", None)

    @staticmethod
    def from_server_args(
        server_args: ServerArgs,
        model_path: str = None,
        model_revision: str = None,
        **kwargs,
    ):
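        """Build a ModelConfig from ServerArgs, optionally overriding the model path and revision."""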
        return ModelConfig(
            model_path=model_path or server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=model_revision or server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
            model_impl=server_args.model_impl,
            **kwargs,
        )

    def _config_draft_model(self):
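        """For speculative-decoding draft models, rewrite the architecture name so the NextN/MTP draft variant is loaded."""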
        is_draft_model = self.is_draft_model

        if (
            is_draft_model
            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
        ):
            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"

        if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
            self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"

        if (
            is_draft_model
            and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
        ):
            self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
            self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers

        if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
            self.hf_config.architectures[0] = "MiMoMTP"
strgrb's avatar
strgrb committed
213
214
215
216
217
        if is_draft_model and self.hf_config.architectures[0] in [
            "BailingMoeV2ForCausalLM",
            "BailingMoeForCausalLM",
        ]:
            self.hf_config.architectures[0] = "BailingMoeForCausalLMNextN"
218
219
220
221
222
        if (
            is_draft_model
            and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
        ):
            self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"

        if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM":
            self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
            self.hf_config.num_nextn_predict_layers = 1

    def _derive_context_length(self, context_length: int):
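        """Resolve the final context length from the user-provided value and the value derived from the HF config."""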
        is_draft_model = self.is_draft_model
        derived_context_len = get_context_length(self.hf_text_config)

        if context_length is not None:
            if context_length > derived_context_len:
                reason = "Target model's" if is_draft_model else "User-specified"
                msg = (
                    f"Warning: {reason} context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                    f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
                )
                if (
                    get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
                    or is_in_ci()  # FIXME: fix this special case
                ):
                    logger.warning(msg)
                    self.context_len = context_length
                    if is_draft_model:
                        self.hf_text_config.max_position_embeddings = context_length
                        logger.warning(
                            f"Overriding the draft model's max_position_embeddings to {context_length}."
                        )
                else:
                    raise ValueError(
                        f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                    )
            else:
                self.context_len = context_length
        else:
            self.context_len = derived_context_len

        # Transfer context_len to HuggingFace config so models can access it
        self.hf_config.context_len = self.context_len

    def _derive_model_shapes(self):
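        """Derive attention shapes (head_dim, attention arch, KV heads, layer counts, vocab size) from the HF text config."""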
        # Unify the config keys for hf_text_config
        self.head_dim = getattr(
            self.hf_text_config,
            "head_dim",
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
        )

        # FIXME: temporary special-case handling for MLA architectures
        if (
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
            or "LongcatFlashForCausalLM" in self.hf_config.architectures
            or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
            or "DotsVLMForCausalLM" in self.hf_config.architectures
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim

            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            if self.hf_config.rope_scaling:
                mscale_all_dim = self.hf_config.rope_scaling.get(
                    "mscale_all_dim", False
                )
                scaling_factor = self.hf_config.rope_scaling["factor"]
                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
                self.scaling = self.scaling * mscale * mscale

        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr(
            self.hf_text_config, "use_mla", True
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
        elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
            self.v_head_dim = self.hf_text_config.v_head_dim
            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
        else:
            if (
                "MistralModel" in self.hf_config.architectures
                or "MixtralForCausalLM" in self.hf_config.architectures
                or "MistralForCausalLM" in self.hf_config.architectures
            ):
                if getattr(self, "head_dim", None) is None:
                    self.head_dim = (
                        self.hf_config.hidden_size // self.hf_config.num_attention_heads
                    )
                    # In transformers==4.52.3, the head_dim is null in MistralConfig
                    if (
                        not hasattr(self.hf_text_config, "head_dim")
                        or self.hf_text_config.head_dim is None
                    ):
                        setattr(self.hf_text_config, "head_dim", self.head_dim)

            self.attention_arch = AttentionArch.MHA

        self.num_attention_heads = self.hf_text_config.num_attention_heads
        self.num_key_value_heads = getattr(
            self.hf_text_config, "num_key_value_heads", None
        )

        # for Dbrx and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
        self.num_attention_layers = self.num_hidden_layers
        if "LongcatFlashForCausalLM" in self.hf_config.architectures:
            self.num_attention_layers = self.num_hidden_layers * 2
        self.num_nextn_predict_layers = getattr(
            self.hf_text_config, "num_nextn_predict_layers", None
        )
        self.vocab_size = self.hf_text_config.vocab_size

    def get_total_num_attention_heads(self) -> int:
        return self.num_attention_heads

    def get_num_attention_heads(self, tensor_parallel_size) -> int:
        total_num_attention_heads = self.num_attention_heads
        return max(1, total_num_attention_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )

        if self.hf_config.model_type in ["nemotron-nas"]:
            nkvh = {
                self.hf_config.num_attention_heads // block.attention.n_heads_in_group
                for block in self.hf_config.block_configs
                if not block.attention.no_op
            }
            if len(nkvh) == 0:
                raise RuntimeError("Couldn't determine number of kv heads")
            if len(nkvh) > 1:
                raise ValueError(
                    "Variable GQA (VGQA) is not yet supported for nemotron-nas in sglang"
                )
            return next(iter(nkvh))

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
            # For Step3
            "num_attention_groups",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, total_num_kv_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _parse_quant_hf_config(self):
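        """Read the quantization config from the HF config, falling back to a standalone hf_quant_config.json if present."""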
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        if quant_cfg is None:
            # Check whether this is a ModelOpt or mixed-precision model. Neither has a
            # corresponding field in the HF `config.json`, but both ship a standalone
            # `hf_quant_config.json` in the model's root directory.
            # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
            # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
            is_local = os.path.exists(self.model_path)
            modelopt_quant_config = {"quant_method": "modelopt"}
            if not is_local:
                import huggingface_hub

                try:
                    from huggingface_hub import HfApi

                    hf_api = HfApi()

                    def check_hf_quant_config():
                        return hf_api.file_exists(
                            self.model_path, "hf_quant_config.json"
                        )

                    # Retry HF API call up to 3 times
                    file_exists = retry(
                        check_hf_quant_config,
                        max_retry=2,
                        initial_delay=1.0,
                        max_delay=5.0,
                    )

                    if file_exists:
                        quant_cfg = modelopt_quant_config

                except huggingface_hub.errors.OfflineModeIsEnabled:
                    logger.warning(
                        "Offline mode is enabled, skipping hf_quant_config.json check"
                    )
                except Exception as e:
                    logger.warning(
                        f"Failed to check hf_quant_config.json: {self.model_path} {e}"
                    )

            elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
                quant_config_file = os.path.join(
                    self.model_path, "hf_quant_config.json"
                )
                with open(quant_config_file) as f:
                    quant_config_dict = json.load(f)
                json_quant_configs = quant_config_dict["quantization"]
                quant_algo = json_quant_configs.get("quant_algo", None)
                if quant_algo == "MIXED_PRECISION":
                    quant_cfg = {"quant_method": "w4afp8"}
                else:
                    quant_cfg = modelopt_quant_config
        return quant_cfg

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _verify_quantization(self) -> None:
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq",
            "gptq",
            "fp8",
            "compressed_tensors",
            "compressed-tensors",
            "fbgemm_fp8",
            "w8a8_fp8",
            "petit_nvfp4",
            "quark",
            "mxfp4",
        ]
        optimized_quantization_methods = [
            "fp8",
            "marlin",
            "modelopt",
            "gptq_marlin_24",
            "gptq_marlin",
            "awq_marlin",
            "fbgemm_fp8",
            "compressed_tensors",
            "compressed-tensors",
            "experts_int8",
            "w8a8_int8",
            "w8a8_fp8",
            "moe_wna16",
            "qoq",
            "w4afp8",
            "petit_nvfp4",
        ]
        compatible_quantization_methods = {
            "modelopt_fp4": ["modelopt"],
            "petit_nvfp4": ["modelopt"],
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
        }

        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get(
                "quant_method", "" if not self.quantization else self.quantization
            ).lower()

            # Detect which type of checkpoint this is
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization
                )
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                if (
                    self.quantization not in compatible_quantization_methods
                    or quant_method
                    not in compatible_quantization_methods[self.quantization]
                ):
                    raise ValueError(
                        "Quantization method specified in the model config "
                        f"({quant_method}) does not match the quantization "
                        f"method specified in the `quantization` argument "
                        f"({self.quantization})."
                    )

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}."
                )
            if is_hip() and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm."
                )
            if self.quantization not in optimized_quantization_methods:
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.",
                    self.quantization,
                )

    def _verify_dual_chunk_attention_config(self) -> None:
        if hasattr(self.hf_config, "dual_chunk_attention_config"):
            # Try loading the sparse attention config
            sparse_attn_config = get_sparse_attention_config(self.model_path)
            if not sparse_attn_config:
                return
            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
                sparse_attn_config
            )
            if (
                "sparse_attention_enabled"
                not in self.hf_config.dual_chunk_attention_config
            ):
                self.hf_config.dual_chunk_attention_config[
                    "sparse_attention_enabled"
                ] = True

    def _get_hf_eos_token_id(self) -> Optional[Set[int]]:
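        """Collect EOS token ids from both the HF model config and the generation config."""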
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
        if eos_ids is None:
            eos_ids = set()
        if self.hf_generation_config:
            generation_eos_ids = getattr(
                self.hf_generation_config, "eos_token_id", None
            )
            if generation_eos_ids:
                generation_eos_ids = (
                    {generation_eos_ids}
                    if isinstance(generation_eos_ids, int)
                    else set(generation_eos_ids)
                )
                eos_ids = eos_ids | generation_eos_ids
        return eos_ids

    def _maybe_pull_model_tokenizer_from_remote(self) -> None:
        """
        Pull the model config files to a temporary
        directory in case of remote.

        Args:
            model: The model name or path.

        """
        from sglang.srt.connector import create_remote_connector
        from sglang.srt.utils import is_remote_url

        if is_remote_url(self.model_path):
            logger.info("Pulling model configs from remote...")
            # BaseConnector implements __del__() to clean up the local dir.
            # Since the config files need to exist all the time, we DO NOT use a
            # with statement, to avoid closing the client.
            client = create_remote_connector(self.model_path)
            if is_remote_url(self.model_path):
                client.pull_files(allow_pattern=["*config.json"])
                self.model_weights = self.model_path
                self.model_path = client.get_local_dir()


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if isinstance(config_dtype, str):
        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type.startswith("gemma"):
                    if config.model_type == "gemma":
                        gemma_version = ""
                    else:
                        gemma_version = config.model_type[5]
                    logger.info(
                        f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16."
                    )
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype


def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    # We have two ways to determine whether a model is a generative model.
    # 1. Check the model architecture
    # 2. Check the `is_embedding` server args

    if (
        "LlamaEmbeddingModel" in model_architectures
        or "MistralModel" in model_architectures
        or "LlamaForSequenceClassification" in model_architectures
        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
        or "InternLM2ForRewardModel" in model_architectures
        or "Qwen2ForRewardModel" in model_architectures
        or "Qwen2ForSequenceClassification" in model_architectures
        or "Qwen3ForSequenceClassification" in model_architectures
        or "CLIPModel" in model_architectures
        or "BertModel" in model_architectures
        or "Contriever" in model_architectures
        or "BertForSequenceClassification" in model_architectures
        or "XLMRobertaModel" in model_architectures
        or "XLMRobertaForSequenceClassification" in model_architectures
    ):
        return False
    else:
        return not is_embedding


multimodal_model_archs = [
    "CLIPModel",
    "DeepseekVL2ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3nForConditionalGeneration",
    "Glm4vForConditionalGeneration",
    "Glm4vMoeForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "LlavaLlamaForCausalLM",
    "Llama4ForConditionalGeneration",
    "LlavaMistralForCausalLM",
    "LlavaQwenForCausalLM",
    "LlavaForConditionalGeneration",
    "LlavaVidForCausalLM",
    "MiniCPMO",
    "MiniCPMV",
    "Mistral3ForConditionalGeneration",
    "MultiModalityCausalLM",
    "MllamaForConditionalGeneration",
    "Qwen2AudioForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "Qwen3VLForConditionalGeneration",
    "Qwen3VLMoeForConditionalGeneration",
    "KimiVLForConditionalGeneration",
    "InternVLChatModel",
    "InternS1ForConditionalGeneration",
    "Phi4MMForCausalLM",
    "VILAForConditionalGeneration",
    "Step3VLForConditionalGeneration",
    "DotsVLMForCausalLM",
    "DotsOCRForCausalLM",
    "Sarashina2VisionForCausalLM",
]


def is_multimodal_model(model_architectures: List[str]):
    if any(
        multi_model_arch in model_architectures
        for multi_model_arch in multimodal_model_archs
    ):
        return True
    else:
        return False


def is_multimodal_gen_model(model_architectures: List[str]):
    return False


def is_image_gen_model(model_architectures: List[str]):
    return False


def is_audio_model(model_architectures: List[str]):
    return False


def is_encoder_decoder_model(model_architectures: List[str]):
    return "MllamaForConditionalGeneration" in model_architectures


def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
    """Check if chunked prefill is supported for a MultiModal model."""
    unsupported = [
        "Grok1VForCausalLM",
        "Grok1AForCausalLM",
        "LlavaLlamaForCausalLM",
        "MllamaForConditionalGeneration",
        "CLIPModel",
    ]
    if any(multi_model_arch in unsupported for multi_model_arch in model_architectures):
        return False
    else:
        return True


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
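    """Return the YaRN attention scaling factor (mscale) for a given RoPE scaling factor; 1.0 means no extra scaling."""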
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def is_hybrid_model(
    model_architectures: List[str],
    hybrid_kvcache_ratio: Optional[float],
    context_length: Optional[int],
    attention_chunk_size: Optional[int],
):
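    """Return the hybrid (SWA + full attention) KV cache ratio if the model qualifies, otherwise None."""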
    if hybrid_kvcache_ratio is None:
        return None
    elif (
        hybrid_kvcache_ratio > 0
        and model_architectures[0] == "Llama4ForConditionalGeneration"
        and context_length > attention_chunk_size
    ):
        return hybrid_kvcache_ratio
    else:
        return None


def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
    if "Llama4ForConditionalGeneration" in model_architectures:
        swa_attention_layer_ids = [
            i for i in range(num_hidden_layers) if (i + 1) % 4 != 0
        ]
        full_attention_layer_ids = [
            i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
        ]
    else:
        swa_attention_layer_ids = None
        full_attention_layer_ids = None
    return swa_attention_layer_ids, full_attention_layer_ids