# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json
import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import List, Optional, Set, Union

import torch
from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_generation_config,
    get_hf_text_config,
    get_sparse_attention_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, is_hip
from sglang.utils import is_in_ci


logger = logging.getLogger(__name__)


class AttentionArch(IntEnum):
    MLA = auto()
    MHA = auto()


class ModelImpl(str, Enum):
    AUTO = "auto"
    SGLANG = "sglang"
    TRANSFORMERS = "transformers"


class ModelConfig:
    def __init__(
        self,
        model_path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: str = "{}",
        is_embedding: Optional[bool] = None,
        enable_multimodal: Optional[bool] = None,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        override_config_file: Optional[str] = None,
        is_draft_model: bool = False,
        hybrid_kvcache_ratio: Optional[float] = None,
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
        tp_rank: Optional[int] = None,
        remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
        remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
        remote_instance_weight_loader_send_weights_group_ports: Optional[
            List[int]
        ] = None,
    ) -> None:
        # Parse args
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization
        self.model_impl = model_impl
        self.tp_rank = tp_rank
        self.remote_instance_weight_loader_seed_instance_ip = (
            remote_instance_weight_loader_seed_instance_ip
        )
        self.remote_instance_weight_loader_seed_instance_service_port = (
            remote_instance_weight_loader_seed_instance_service_port
        )
        self.remote_instance_weight_loader_send_weights_group_ports = (
            remote_instance_weight_loader_send_weights_group_ports
        )

        self.maybe_pull_model_tokenizer_from_remote()
        self.model_override_args = json.loads(model_override_args)
        kwargs = {}
        if override_config_file and override_config_file.strip():
            kwargs["_configuration_file"] = override_config_file.strip()

        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
            **kwargs,
        )

        self.hf_generation_config = get_generation_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )

        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.attention_chunk_size = getattr(
            self.hf_text_config, "attention_chunk_size", None
        )
        self.is_hybrid = is_hybrid_model(
            self.hf_config.architectures,
            hybrid_kvcache_ratio=hybrid_kvcache_ratio,
            context_length=context_length,
            attention_chunk_size=self.attention_chunk_size,
        )
        if self.is_hybrid is not None:
            self.swa_attention_layer_ids, self.full_attention_layer_ids = (
                get_hybrid_layer_ids(
                    self.hf_config.architectures, self.hf_text_config.num_hidden_layers
                )
            )

        if enable_multimodal is None:
            mm_disabled_models = [
                "Gemma3ForConditionalGeneration",
                "Llama4ForConditionalGeneration",
                "Step3VLForConditionalGeneration",
            ]
            if self.hf_config.architectures[0] in mm_disabled_models:
                enable_multimodal = False
                logger.info(
                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
                )
            else:
                enable_multimodal = True

        if (
            is_draft_model
            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
        ):
            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"

        if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
            self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"

        if (
            is_draft_model
            and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
        ):
            self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
            self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers

        if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
            self.hf_config.architectures[0] = "MiMoMTP"
        if is_draft_model and self.hf_config.architectures[0] in [
            "BailingMoeV2ForCausalLM",
            "BailingMoeForCausalLM",
        ]:
            self.hf_config.architectures[0] = "BailingMoeForCausalLMNextN"
        if (
            is_draft_model
            and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
        ):
            self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"

        if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM":
            self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"

        # Check model type
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
        )
        self.is_multimodal = enable_multimodal and is_multimodal_model(
            self.hf_config.architectures
        )
        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
            self.hf_config.architectures
        )
        self.is_image_gen = enable_multimodal and is_image_gen_model(
            self.hf_config.architectures
        )
        self.is_audio_model = enable_multimodal and is_audio_model(
            self.hf_config.architectures
        )
        self.is_multimodal_chunked_prefill_supported = (
            enable_multimodal
            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
        )
        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

        # Derive context length
        derived_context_len = get_context_length(self.hf_text_config)
        if context_length is not None:
            if context_length > derived_context_len:
                reason = "Target model's" if is_draft_model else "User-specified"
                msg = (
                    f"Warning: {reason} context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                    f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
                )
                if (
                    get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
                    or is_in_ci()  # FIXME: fix this special case
                ):
                    logger.warning(msg)
                    self.context_len = context_length
                else:
                    raise ValueError(
                        f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                    )
            else:
                self.context_len = context_length
        else:
            self.context_len = derived_context_len
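        # Illustrative behavior of the check above (hypothetical numbers): if the
        # derived context length is 8192 and context_length=32768 is requested,
        # initialization raises unless SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
        # is set, in which case the longer value is used with a warning.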

        # Unify the config keys for hf_text_config
        self.head_dim = getattr(
            self.hf_text_config,
            "head_dim",
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
        )
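        # Illustrative fallback (hypothetical values): a config with hidden_size=4096
        # and num_attention_heads=32 that does not define `head_dim` resolves to
        # 4096 // 32 = 128 here.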

        # FIXME: temporary special-case handling for MLA architectures
        if (
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
            or "LongcatFlashForCausalLM" in self.hf_config.architectures
            or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
            or "DotsVLMForCausalLM" in self.hf_config.architectures
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim

            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            if self.hf_config.rope_scaling:
                mscale_all_dim = self.hf_config.rope_scaling.get(
                    "mscale_all_dim", False
                )
                scaling_factor = self.hf_config.rope_scaling["factor"]
                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
                self.scaling = self.scaling * mscale * mscale
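            # Worked example (hypothetical DeepSeek-style values): with
            # qk_nope_head_dim=128 and qk_rope_head_dim=64, the base scaling is
            # 1 / sqrt(192) ~= 0.0722; with rope_scaling {"factor": 40,
            # "mscale_all_dim": 1.0}, yarn_get_mscale(40, 1.0) ~= 1.3689 and the
            # final scaling becomes ~0.0722 * 1.3689**2 ~= 0.135.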

        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr(
            self.hf_text_config, "use_mla", True
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
        elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
            self.v_head_dim = self.hf_text_config.v_head_dim
            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
        else:
            if (
                "MistralModel" in self.hf_config.architectures
                or "MixtralForCausalLM" in self.hf_config.architectures
                or "MistralForCausalLM" in self.hf_config.architectures
            ):
                if getattr(self, "head_dim", None) is None:
                    self.head_dim = (
                        self.hf_config.hidden_size // self.hf_config.num_attention_heads
                    )
                    # In transformers==4.52.3, the head_dim is null in MistralConfig
                    if (
                        not hasattr(self.hf_text_config, "head_dim")
                        or self.hf_text_config.head_dim is None
                    ):
                        setattr(self.hf_text_config, "head_dim", self.head_dim)

            self.attention_arch = AttentionArch.MHA

        self.num_attention_heads = self.hf_text_config.num_attention_heads
        self.num_key_value_heads = getattr(
            self.hf_text_config, "num_key_value_heads", None
        )

        # for Dbrx and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
        self.num_attention_layers = self.num_hidden_layers
        if "LongcatFlashForCausalLM" in self.hf_config.architectures:
            self.num_attention_layers = self.num_hidden_layers * 2
        self.num_nextn_predict_layers = getattr(
            self.hf_text_config, "num_nextn_predict_layers", None
        )
        self.vocab_size = self.hf_text_config.vocab_size

        # Verify quantization
        self._verify_quantization()

        # Verify dual-chunk attention config
        self._verify_dual_chunk_attention_config()

        # Cache attributes
        self.hf_eos_token_id = self.get_hf_eos_token_id()

        # multimodal
        self.image_token_id = getattr(
            self.hf_config, "image_token_id", None
        ) or getattr(self.hf_config, "image_token_index", None)

    @staticmethod
    def from_server_args(
        server_args: ServerArgs,
        model_path: Optional[str] = None,
        model_revision: Optional[str] = None,
        **kwargs,
    ):
        return ModelConfig(
            model_path=model_path or server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=model_revision or server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
            model_impl=server_args.model_impl,
            remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
            remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
            remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
            **kwargs,
        )

    def get_total_num_attention_heads(self) -> int:
        return self.num_attention_heads

    def get_num_attention_heads(self, tensor_parallel_size) -> int:
        total_num_attention_heads = self.num_attention_heads
        return max(1, total_num_attention_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )
        if self.hf_config.model_type in ["nemotron-nas"]:
            nkvh = {
                self.hf_config.num_attention_heads // block.attention.n_heads_in_group
                for block in self.hf_config.block_configs
                if not block.attention.no_op
            }
            if len(nkvh) == 0:
                raise RuntimeError("Couldn't determine number of kv heads")
            if len(nkvh) > 1:
                raise ValueError(
                    "Variable GQA (VGQA) is not yet supported for nemotron-nas in sglang"
                )
            return next(iter(nkvh))

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
            # For Step3
            "num_attention_groups",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, total_num_kv_heads // tensor_parallel_size)
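        # Illustrative sizes (hypothetical): a model with 8 total KV heads gets
        # 2 KV heads per GPU at tensor_parallel_size=4, and 1 replicated KV head
        # per GPU at tensor_parallel_size=16.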

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _parse_quant_hf_config(self):
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        if quant_cfg is None:
            # Check whether this is a ModelOpt or mixed-precision model. Neither has
            # a corresponding field in the HF `config.json`, but both ship a
            # standalone `hf_quant_config.json` in the model root directory.
            # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
            # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
            is_local = os.path.exists(self.model_path)
            modelopt_quant_config = {"quant_method": "modelopt"}
            if not is_local:
                import huggingface_hub

                try:
                    from huggingface_hub import HfApi

                    hf_api = HfApi()
                    if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
                        quant_cfg = modelopt_quant_config
                except huggingface_hub.errors.OfflineModeIsEnabled:
                    logger.warning(
                        "Offline mode is enabled, skipping hf_quant_config.json check"
                    )
                    pass

            elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
                quant_config_file = os.path.join(
                    self.model_path, "hf_quant_config.json"
                )
                with open(quant_config_file) as f:
                    quant_config_dict = json.load(f)
                json_quant_configs = quant_config_dict["quantization"]
                quant_algo = json_quant_configs.get("quant_algo", None)
                if quant_algo == "MIXED_PRECISION":
                    quant_cfg = {"quant_method": "w4afp8"}
                else:
                    quant_cfg = modelopt_quant_config
        return quant_cfg
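        # Minimal sketch of the standalone `hf_quant_config.json` handled above
        # (hypothetical contents):
        #   {"quantization": {"quant_algo": "MIXED_PRECISION"}} -> {"quant_method": "w4afp8"}
        #   {"quantization": {"quant_algo": "FP8"}}             -> {"quant_method": "modelopt"}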

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _verify_quantization(self) -> None:
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq",
            "gptq",
            "fp8",
            "compressed_tensors",
            "compressed-tensors",
            "fbgemm_fp8",
            "w8a8_fp8",
            "petit_nvfp4",
            "quark",
            "mxfp4",
        ]
        optimized_quantization_methods = [
            "fp8",
            "marlin",
            "modelopt",
            "gptq_marlin_24",
            "gptq_marlin",
            "awq_marlin",
            "fbgemm_fp8",
            "compressed_tensors",
            "compressed-tensors",
            "experts_int8",
            "w8a8_int8",
            "w8a8_fp8",
            "moe_wna16",
            "qoq",
            "w4afp8",
            "petit_nvfp4",
        ]
        compatible_quantization_methods = {
            "modelopt_fp4": ["modelopt"],
            "petit_nvfp4": ["modelopt"],
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
        }
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get(
                "quant_method", "" if not self.quantization else self.quantization
            ).lower()

            # Detect which checkpoint format this is
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization
                )
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                if (
                    self.quantization not in compatible_quantization_methods
                    or quant_method
                    not in compatible_quantization_methods[self.quantization]
                ):
                    raise ValueError(
                        "Quantization method specified in the model config "
                        f"({quant_method}) does not match the quantization "
                        f"method specified in the `quantization` argument "
                        f"({self.quantization})."
                    )

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}."
                )
            if is_hip() and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm."
                )
            if self.quantization not in optimized_quantization_methods:
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.",
                    self.quantization,
                )

    def _verify_dual_chunk_attention_config(self) -> None:
        if hasattr(self.hf_config, "dual_chunk_attention_config"):
            # Try loading the sparse attention config
            sparse_attn_config = get_sparse_attention_config(self.model_path)
            if not sparse_attn_config:
                return
            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
                sparse_attn_config
            )
            if (
                "sparse_attention_enabled"
                not in self.hf_config.dual_chunk_attention_config
            ):
                self.hf_config.dual_chunk_attention_config[
                    "sparse_attention_enabled"
                ] = True

    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids is not None:
            # It can be either an int or a list of ints
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
        if eos_ids is None:
            eos_ids = set()
        if self.hf_generation_config:
            generation_eos_ids = getattr(
                self.hf_generation_config, "eos_token_id", None
            )
            if generation_eos_ids:
                generation_eos_ids = (
                    {generation_eos_ids}
                    if isinstance(generation_eos_ids, int)
                    else set(generation_eos_ids)
                )
                eos_ids = eos_ids | generation_eos_ids
        return eos_ids
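        # Illustrative merge (hypothetical token ids): a config eos_token_id of 2
        # combined with a generation_config eos_token_id of [2, 32000] returns
        # {2, 32000}.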

    def maybe_pull_model_tokenizer_from_remote(self) -> None:
        """
        Pull the model config files to a temporary
        directory in case of remote.

        Args:
            model: The model name or path.

        """
        from sglang.srt.connector import create_remote_connector
        from sglang.srt.utils import is_remote_url

        if is_remote_url(self.model_path):
            logger.info("Pulling model configs from remote...")
            # BaseConnector implements __del__() to clean up the local dir.
            # Because the config files must remain available for the whole process
            # lifetime, we deliberately do NOT use a `with` statement, which would
            # close the client and clean up the directory.
            client = create_remote_connector(self.model_path)
            client.pull_files(allow_pattern=["*config.json"])
            self.model_weights = self.model_path
            self.model_path = client.get_local_dir()


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if isinstance(config_dtype, str):
        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type.startswith("gemma"):
                    if config.model_type == "gemma":
                        gemma_version = ""
                    else:
                        gemma_version = config.model_type[5]
                    logger.info(
                        f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16."
                    )
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
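    # Illustrative resolutions (assuming a non-Gemma model): a float32 checkpoint with
    # dtype="auto" resolves to torch.float16, a bfloat16 checkpoint with dtype="auto"
    # stays torch.bfloat16, and dtype="bfloat16" on a float16 checkpoint returns
    # torch.bfloat16 with a cast warning.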


def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    # We have two ways to determine whether a model is a generative model.
    # 1. Check the model architecture
    # 2. Check the `is_embedding` server argument

    if (
        "LlamaEmbeddingModel" in model_architectures
        or "MistralModel" in model_architectures
        or "LlamaForSequenceClassification" in model_architectures
        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
        or "InternLM2ForRewardModel" in model_architectures
        or "Qwen2ForRewardModel" in model_architectures
        or "Qwen2ForSequenceClassification" in model_architectures
        or "Qwen3ForSequenceClassification" in model_architectures
        or "CLIPModel" in model_architectures
        or "BertModel" in model_architectures
        or "Contriever" in model_architectures
        or "BertForSequenceClassification" in model_architectures
        or "XLMRobertaModel" in model_architectures
        or "XLMRobertaForSequenceClassification" in model_architectures
    ):
        return False
    else:
        return not is_embedding


multimodal_model_archs = [
    "CLIPModel",
    "DeepseekVL2ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3nForConditionalGeneration",
    "Glm4vForConditionalGeneration",
    "Glm4vMoeForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "LlavaLlamaForCausalLM",
    "Llama4ForConditionalGeneration",
    "LlavaMistralForCausalLM",
    "LlavaQwenForCausalLM",
    "LlavaForConditionalGeneration",
    "LlavaVidForCausalLM",
    "MiniCPMO",
    "MiniCPMV",
    "Mistral3ForConditionalGeneration",
    "MultiModalityCausalLM",
    "MllamaForConditionalGeneration",
    "Qwen2AudioForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "KimiVLForConditionalGeneration",
    "InternVLChatModel",
    "InternS1ForConditionalGeneration",
    "Phi4MMForCausalLM",
    "VILAForConditionalGeneration",
    "Step3VLForConditionalGeneration",
    "DotsVLMForCausalLM",
]


def is_multimodal_model(model_architectures: List[str]):
    if any(
        multi_model_arch in model_architectures
        for multi_model_arch in multimodal_model_archs
    ):
        return True
    else:
        return False


def is_multimodal_gen_model(model_architectures: List[str]):
    return False


def is_image_gen_model(model_architectures: List[str]):
    return False


def is_audio_model(model_architectures: List[str]):
    return False


def is_encoder_decoder_model(model_architectures: List[str]):
    return "MllamaForConditionalGeneration" in model_architectures


def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
    """Check if chunked prefill is supported for a MultiModal model."""
    unsupported = [
        "Grok1VForCausalLM",
        "Grok1AForCausalLM",
        "LlavaLlamaForCausalLM",
        "MllamaForConditionalGeneration",
        "CLIPModel",
    ]
    if any(multi_model_arch in unsupported for multi_model_arch in model_architectures):
        return False
    else:
        return True


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
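    # Illustrative values: yarn_get_mscale(1.0) == 1.0 (no extra scaling), while
    # yarn_get_mscale(40, 1.0) == 0.1 * 1.0 * ln(40) + 1 ~= 1.369.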


def is_hybrid_model(
    model_architectures: List[str],
    hybrid_kvcache_ratio: Optional[float],
    context_length: Optional[int],
    attention_chunk_size: Optional[int],
):
    if hybrid_kvcache_ratio is None:
        return None
    elif (
        hybrid_kvcache_ratio > 0
        and model_architectures[0] == "Llama4ForConditionalGeneration"
        and context_length > attention_chunk_size
    ):
        return hybrid_kvcache_ratio
    else:
        return None


def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
    if "Llama4ForConditionalGeneration" in model_architectures:
        swa_attention_layer_ids = [
            i for i in range(num_hidden_layers) if (i + 1) % 4 != 0
        ]
        full_attention_layer_ids = [
            i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
        ]
    else:
        swa_attention_layer_ids = None
        full_attention_layer_ids = None
    return swa_attention_layer_ids, full_attention_layer_ids
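    # Illustrative split (hypothetical 8-layer Llama4ForConditionalGeneration):
    # every fourth layer uses full attention, so swa_attention_layer_ids is
    # [0, 1, 2, 4, 5, 6] and full_attention_layer_ids is [3, 7].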