# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json
import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import List, Optional, Set, Union

import torch
from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_generation_config,
    get_hf_text_config,
    get_sparse_attention_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, is_hip
from sglang.utils import is_in_ci

logger = logging.getLogger(__name__)


class AttentionArch(IntEnum):
    MLA = auto()
    MHA = auto()


class ModelImpl(str, Enum):
    AUTO = "auto"
    SGLANG = "sglang"
    TRANSFORMERS = "transformers"


class ModelConfig:
    def __init__(
        self,
        model_path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: str = "{}",
        is_embedding: Optional[bool] = None,
        enable_multimodal: Optional[bool] = None,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        override_config_file: Optional[str] = None,
        is_draft_model: bool = False,
        hybrid_kvcache_ratio: Optional[float] = None,
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
        tp_rank: Optional[int] = None,
        remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
        remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
        remote_instance_weight_loader_send_weights_group_ports: Optional[
            List[int]
        ] = None,
    ) -> None:
        # Parse args
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization
        self.model_impl = model_impl
        self.tp_rank = tp_rank
        self.remote_instance_weight_loader_seed_instance_ip = (
            remote_instance_weight_loader_seed_instance_ip
        )
        self.remote_instance_weight_loader_seed_instance_service_port = (
            remote_instance_weight_loader_seed_instance_service_port
        )
        self.remote_instance_weight_loader_send_weights_group_ports = (
            remote_instance_weight_loader_send_weights_group_ports
        )

        self.maybe_pull_model_tokenizer_from_remote()
        self.model_override_args = json.loads(model_override_args)
        kwargs = {}
        if override_config_file and override_config_file.strip():
            kwargs["_configuration_file"] = override_config_file.strip()

        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
            **kwargs,
        )

        self.hf_generation_config = get_generation_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )

        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.attention_chunk_size = getattr(
            self.hf_text_config, "attention_chunk_size", None
        )
        self.is_hybrid = is_hybrid_model(
            self.hf_config.architectures,
            hybrid_kvcache_ratio=hybrid_kvcache_ratio,
            context_length=context_length,
            attention_chunk_size=self.attention_chunk_size,
        )
        if self.is_hybrid is not None:
            self.swa_attention_layer_ids, self.full_attention_layer_ids = (
                get_hybrid_layer_ids(
                    self.hf_config.architectures, self.hf_text_config.num_hidden_layers
                )
            )

        if enable_multimodal is None:
            mm_disabled_models = [
                "Gemma3ForConditionalGeneration",
                "Llama4ForConditionalGeneration",
                "Step3VLForConditionalGeneration",
            ]
            if self.hf_config.architectures[0] in mm_disabled_models:
                enable_multimodal = False
                logger.info(
                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
                )
            else:
                enable_multimodal = True

        if (
            is_draft_model
            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
        ):
            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"

        if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
            self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"

        if (
            is_draft_model
            and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
        ):
            self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
            self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers

        if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
            self.hf_config.architectures[0] = "MiMoMTP"
        if is_draft_model and self.hf_config.architectures[0] in [
            "BailingMoeV2ForCausalLM",
            "BailingMoeForCausalLM",
        ]:
            self.hf_config.architectures[0] = "BailingMoeForCausalLMNextN"
        if (
            is_draft_model
            and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
        ):
            self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"

        if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM":
            self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
            self.hf_config.num_nextn_predict_layers = 1

        # Check model type
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
        )
        self.is_multimodal = enable_multimodal and is_multimodal_model(
            self.hf_config.architectures
        )
        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
            self.hf_config.architectures
        )
        self.is_image_gen = enable_multimodal and is_image_gen_model(
            self.hf_config.architectures
        )
        self.is_audio_model = enable_multimodal and is_audio_model(
            self.hf_config.architectures
        )
        self.is_multimodal_chunked_prefill_supported = (
            enable_multimodal
            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
        )
        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

        # Derive context length
        derived_context_len = get_context_length(self.hf_text_config)
        if context_length is not None:
            if context_length > derived_context_len:
                reason = "Target model's" if is_draft_model else "User-specified"
                msg = (
                    f"Warning: {reason} context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                    f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
                )
                if (
                    get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
                    or is_in_ci()  # FIXME: fix this special case
                ):
                    logger.warning(msg)
                    self.context_len = context_length
                else:
                    raise ValueError(
                        f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                    )
            else:
                self.context_len = context_length
        else:
            self.context_len = derived_context_len

        # Unify the config keys for hf_text_config
        self.head_dim = getattr(
            self.hf_text_config,
            "head_dim",
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
        )

        # FIXME: temporary special-case handling for MLA architectures
        if (
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
            or "LongcatFlashForCausalLM" in self.hf_config.architectures
            or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
            or "DotsVLMForCausalLM" in self.hf_config.architectures
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim

            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            if self.hf_config.rope_scaling:
                mscale_all_dim = self.hf_config.rope_scaling.get(
                    "mscale_all_dim", False
                )
                scaling_factor = self.hf_config.rope_scaling["factor"]
                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
                self.scaling = self.scaling * mscale * mscale

        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr(
            self.hf_text_config, "use_mla", True
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
        elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
            self.v_head_dim = self.hf_text_config.v_head_dim
            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
        else:
            if (
                "MistralModel" in self.hf_config.architectures
                or "MixtralForCausalLM" in self.hf_config.architectures
                or "MistralForCausalLM" in self.hf_config.architectures
            ):
                if getattr(self, "head_dim", None) is None:
                    self.head_dim = (
                        self.hf_config.hidden_size // self.hf_config.num_attention_heads
                    )
                    # In transformers==4.52.3, the head_dim is null in MistralConfig
                    if (
                        not hasattr(self.hf_text_config, "head_dim")
                        or self.hf_text_config.head_dim is None
                    ):
                        setattr(self.hf_text_config, "head_dim", self.head_dim)

            self.attention_arch = AttentionArch.MHA

        self.num_attention_heads = self.hf_text_config.num_attention_heads
        self.num_key_value_heads = getattr(
            self.hf_text_config, "num_key_value_heads", None
        )

        # For DBRX and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
        self.num_attention_layers = self.num_hidden_layers
        if "LongcatFlashForCausalLM" in self.hf_config.architectures:
            self.num_attention_layers = self.num_hidden_layers * 2
        self.num_nextn_predict_layers = getattr(
            self.hf_text_config, "num_nextn_predict_layers", None
        )
        self.vocab_size = self.hf_text_config.vocab_size

        # Verify quantization
        self._verify_quantization()

        # Verify dual-chunk attention config
        self._verify_dual_chunk_attention_config()

        # Cache attributes
        self.hf_eos_token_id = self.get_hf_eos_token_id()

        # Multimodal: resolve the image token id from the HF config
        self.image_token_id = getattr(
            self.hf_config, "image_token_id", None
        ) or getattr(self.hf_config, "image_token_index", None)

    @staticmethod
    def from_server_args(
        server_args: ServerArgs,
        model_path: Optional[str] = None,
        model_revision: Optional[str] = None,
        **kwargs,
    ):
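        """Build a ModelConfig from ServerArgs, optionally overriding the model path and revision."""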
        return ModelConfig(
            model_path=model_path or server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=model_revision or server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
            model_impl=server_args.model_impl,
            remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
            remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
            remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
            **kwargs,
        )
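    # Typical usage (illustrative sketch; assumes a ServerArgs constructed elsewhere):
    #   server_args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")
    #   model_config = ModelConfig.from_server_args(server_args)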

    def get_total_num_attention_heads(self) -> int:
        return self.num_attention_heads

    def get_num_attention_heads(self, tensor_parallel_size) -> int:
        total_num_attention_heads = self.num_attention_heads
        return max(1, total_num_attention_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )
        if self.hf_config.model_type in ["nemotron-nas"]:
            nkvh = {
                self.hf_config.num_attention_heads // block.attention.n_heads_in_group
                for block in self.hf_config.block_configs
                if not block.attention.no_op
            }
            if len(nkvh) == 0:
                raise RuntimeError("Couldn't determine number of kv heads")
            if len(nkvh) > 1:
                raise ValueError(
                    "Variable GQA (VGQA) is not yet supported for nemotron-nas in sglang"
                )
            return next(iter(nkvh))

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
            # For Step3
            "num_attention_groups",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, total_num_kv_heads // tensor_parallel_size)
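    # Example (illustrative): a checkpoint with 8 KV heads on tensor_parallel_size=16
    # still yields max(1, 8 // 16) == 1 KV head per rank, i.e. KV heads are replicated
    # rather than split below one head per GPU.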

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _parse_quant_hf_config(self):
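        # Returns (illustratively) a dict like {"quant_method": "fp8", ...} when the HF
        # config carries a quantization/compression config, {"quant_method": "modelopt"}
        # or {"quant_method": "w4afp8"} when only a standalone hf_quant_config.json is
        # found, and None for unquantized checkpoints.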
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        if quant_cfg is None:
            # Check whether this is a ModelOpt or mixed-precision checkpoint. Neither
            # stores its settings in the HF `config.json`; instead they ship a
            # standalone `hf_quant_config.json` in the model root directory.
            # Example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
            # Example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
            is_local = os.path.exists(self.model_path)
            modelopt_quant_config = {"quant_method": "modelopt"}
            if not is_local:
                import huggingface_hub

                try:
                    from huggingface_hub import HfApi

                    hf_api = HfApi()
                    if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
                        quant_cfg = modelopt_quant_config
                except huggingface_hub.errors.OfflineModeIsEnabled:
                    logger.warning(
                        "Offline mode is enabled, skipping hf_quant_config.json check"
                    )

            elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
                quant_config_file = os.path.join(
                    self.model_path, "hf_quant_config.json"
                )
                with open(quant_config_file) as f:
                    quant_config_dict = json.load(f)
                json_quant_configs = quant_config_dict["quantization"]
                quant_algo = json_quant_configs.get("quant_algo", None)
                if quant_algo == "MIXED_PRECISION":
                    quant_cfg = {"quant_method": "w4afp8"}
                else:
                    quant_cfg = modelopt_quant_config
        return quant_cfg

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _verify_quantization(self) -> None:
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq",
            "gptq",
            "fp8",
            "compressed_tensors",
            "compressed-tensors",
            "fbgemm_fp8",
            "w8a8_fp8",
            "petit_nvfp4",
            "quark",
            "mxfp4",
        ]
        optimized_quantization_methods = [
            "fp8",
            "marlin",
            "modelopt",
            "gptq_marlin_24",
            "gptq_marlin",
            "awq_marlin",
            "fbgemm_fp8",
            "compressed_tensors",
            "compressed-tensors",
            "experts_int8",
            "w8a8_int8",
            "w8a8_fp8",
            "moe_wna16",
            "qoq",
            "w4afp8",
            "petit_nvfp4",
        ]
        compatible_quantization_methods = {
            "modelopt_fp4": ["modelopt"],
            "petit_nvfp4": ["modelopt"],
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
        }
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get(
                "quant_method", "" if not self.quantization else self.quantization
            ).lower()

            # Detect which checkpoint format this is.
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization
                )
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                if (
                    self.quantization not in compatible_quantization_methods
                    or quant_method
                    not in compatible_quantization_methods[self.quantization]
                ):
                    raise ValueError(
                        "Quantization method specified in the model config "
                        f"({quant_method}) does not match the quantization "
                        f"method specified in the `quantization` argument "
                        f"({self.quantization})."
                    )

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}."
                )
            if is_hip() and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm."
                )
            if self.quantization not in optimized_quantization_methods:
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.",
                    self.quantization,
                )

    def _verify_dual_chunk_attention_config(self) -> None:
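        """Attach the sparse attention config to the HF dual_chunk_attention_config, if one exists."""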
        if hasattr(self.hf_config, "dual_chunk_attention_config"):
            # Try loading the sparse attention config
            sparse_attn_config = get_sparse_attention_config(self.model_path)
            if not sparse_attn_config:
                return
            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
                sparse_attn_config
            )
            if (
                "sparse_attention_enabled"
                not in self.hf_config.dual_chunk_attention_config
            ):
                self.hf_config.dual_chunk_attention_config[
                    "sparse_attention_enabled"
                ] = True

    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
        if eos_ids is None:
            eos_ids = set()
        if self.hf_generation_config:
            generation_eos_ids = getattr(
                self.hf_generation_config, "eos_token_id", None
            )
            if generation_eos_ids:
                generation_eos_ids = (
                    {generation_eos_ids}
                    if isinstance(generation_eos_ids, int)
                    else set(generation_eos_ids)
                )
                eos_ids = eos_ids | generation_eos_ids
        return eos_ids

    def maybe_pull_model_tokenizer_from_remote(self) -> None:
        """
        Pull the model config files to a temporary
        directory in case of remote.

        Args:
            model: The model name or path.

        """
        from sglang.srt.connector import create_remote_connector
        from sglang.srt.utils import is_remote_url

        if is_remote_url(self.model_path):
            logger.info("Pulling model configs from remote...")
            # BaseConnector implements __del__() to clean up the local dir.
            # Since the config files need to exist for the server's lifetime,
            # we deliberately do NOT use a with statement, to avoid closing
            # the client and deleting the local dir.
            client = create_remote_connector(self.model_path)
            if is_remote_url(self.model_path):
                client.pull_files(allow_pattern=["*config.json"])
                self.model_weights = self.model_path
                self.model_path = client.get_local_dir()


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if isinstance(config_dtype, str):
        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type.startswith("gemma"):
                    if config.model_type == "gemma":
                        gemma_version = ""
                    else:
                        gemma_version = config.model_type[5]
                    logger.info(
                        f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16."
                    )
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
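# A minimal sketch of the resolution rules above (illustrative values, not part of
# any public API): "auto" keeps the checkpoint dtype unless it is float32, which is
# downcast to float16 (bfloat16 for Gemma); explicit strings are looked up in
# _STR_DTYPE_TO_TORCH_DTYPE.
#
#   class _Cfg:  # hypothetical stand-in for a PretrainedConfig
#       model_type = "llama"
#       torch_dtype = torch.float32
#
#   _get_and_verify_dtype(_Cfg(), "auto")      # -> torch.float16
#   _get_and_verify_dtype(_Cfg(), "bfloat16")  # -> torch.bfloat16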


def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    # We have two ways to determine whether a model is a generative model.
    # 1. Check the model architecture
    # 2. Check the `is_embedding` server argument

    if (
        "LlamaEmbeddingModel" in model_architectures
        or "MistralModel" in model_architectures
        or "LlamaForSequenceClassification" in model_architectures
        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
        or "InternLM2ForRewardModel" in model_architectures
        or "Qwen2ForRewardModel" in model_architectures
        or "Qwen2ForSequenceClassification" in model_architectures
        or "Qwen3ForSequenceClassification" in model_architectures
        or "CLIPModel" in model_architectures
        or "BertModel" in model_architectures
        or "Contriever" in model_architectures
        or "BertForSequenceClassification" in model_architectures
        or "XLMRobertaModel" in model_architectures
        or "XLMRobertaForSequenceClassification" in model_architectures
    ):
        return False
    else:
        return not is_embedding


multimodal_model_archs = [
    "CLIPModel",
    "DeepseekVL2ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3nForConditionalGeneration",
    "Glm4vForConditionalGeneration",
    "Glm4vMoeForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "LlavaLlamaForCausalLM",
    "Llama4ForConditionalGeneration",
    "LlavaMistralForCausalLM",
    "LlavaQwenForCausalLM",
    "LlavaForConditionalGeneration",
    "LlavaVidForCausalLM",
    "MiniCPMO",
    "MiniCPMV",
    "Mistral3ForConditionalGeneration",
    "MultiModalityCausalLM",
    "MllamaForConditionalGeneration",
    "Qwen2AudioForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "KimiVLForConditionalGeneration",
    "InternVLChatModel",
    "InternS1ForConditionalGeneration",
    "Phi4MMForCausalLM",
    "VILAForConditionalGeneration",
    "Step3VLForConditionalGeneration",
    "DotsVLMForCausalLM",
]


def is_multimodal_model(model_architectures: List[str]):
    return any(
        multi_model_arch in model_architectures
        for multi_model_arch in multimodal_model_archs
    )


def is_multimodal_gen_model(model_architectures: List[str]):
    return False


def is_image_gen_model(model_architectures: List[str]):
    return False


def is_audio_model(model_architectures: List[str]):
    return False


def is_encoder_decoder_model(model_architectures: List[str]):
    return "MllamaForConditionalGeneration" in model_architectures


def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
    """Check if chunked prefill is supported for a MultiModal model."""
    unsupported = [
        "Grok1VForCausalLM",
        "Grok1AForCausalLM",
        "LlavaLlamaForCausalLM",
        "MllamaForConditionalGeneration",
        "CLIPModel",
    ]
    if any(multi_model_arch in unsupported for multi_model_arch in model_architectures):
        return False
    else:
        return True


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
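# Example (illustrative): with a YaRN rope_scaling factor of 40 and mscale_all_dim=1.0,
# yarn_get_mscale(40, 1.0) == 0.1 * 1.0 * math.log(40) + 1.0, i.e. about 1.369; the MLA
# softmax scaling computed in ModelConfig.__init__ is then multiplied by mscale ** 2.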


def is_hybrid_model(
    model_architectures: List[str],
    hybrid_kvcache_ratio: Optional[float],
    context_length: Optional[int],
    attention_chunk_size: Optional[int],
):
    if hybrid_kvcache_ratio is None:
        return None
    elif (
        hybrid_kvcache_ratio > 0
        and model_architectures[0] == "Llama4ForConditionalGeneration"
        and context_length > attention_chunk_size
    ):
        return hybrid_kvcache_ratio
    else:
        return None


def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
    if "Llama4ForConditionalGeneration" in model_architectures:
        swa_attention_layer_ids = [
            i for i in range(num_hidden_layers) if (i + 1) % 4 != 0
        ]
        full_attention_layer_ids = [
            i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
        ]
    else:
        swa_attention_layer_ids = None
        full_attention_layer_ids = None
    return swa_attention_layer_ids, full_attention_layer_ids
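# Example (illustrative): for a hypothetical 8-layer Llama4ForConditionalGeneration,
# get_hybrid_layer_ids(["Llama4ForConditionalGeneration"], 8) returns
# ([0, 1, 2, 4, 5, 6], [3, 7]): every 4th layer uses full attention and the rest use
# sliding-window (chunked) attention.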