# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json
import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import List, Optional, Set, Union

import torch
from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_generation_config,
    get_hf_text_config,
    get_sparse_attention_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, is_hip

logger = logging.getLogger(__name__)


class AttentionArch(IntEnum):
    MLA = auto()
    MHA = auto()


class ModelImpl(str, Enum):
    AUTO = "auto"
    SGLANG = "sglang"
    TRANSFORMERS = "transformers"


class ModelConfig:
    def __init__(
        self,
        model_path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: str = "{}",
        is_embedding: Optional[bool] = None,
        enable_multimodal: Optional[bool] = None,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        override_config_file: Optional[str] = None,
        is_draft_model: bool = False,
        hybrid_kvcache_ratio: Optional[float] = None,
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
    ) -> None:
        # Parse args
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization
        self.model_impl = model_impl

        self.maybe_pull_model_tokenizer_from_remote()
        self.model_override_args = json.loads(model_override_args)

        kwargs = {}
        if override_config_file and override_config_file.strip():
            kwargs["_configuration_file"] = override_config_file.strip()

        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
            **kwargs,
        )

        self.hf_generation_config = get_generation_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )

        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.attention_chunk_size = getattr(
            self.hf_text_config, "attention_chunk_size", None
        )
        self.is_hybrid = is_hybrid_model(
            self.hf_config.architectures,
            hybrid_kvcache_ratio=hybrid_kvcache_ratio,
            context_length=context_length,
            attention_chunk_size=self.attention_chunk_size,
        )
        if self.is_hybrid is not None:
            self.swa_attention_layer_ids, self.full_attention_layer_ids = (
                get_hybrid_layer_ids(
                    self.hf_config.architectures, self.hf_text_config.num_hidden_layers
                )
            )

        if enable_multimodal is None:
            mm_disabled_models = [
                "Gemma3ForConditionalGeneration",
                "Llama4ForConditionalGeneration",
                "Step3VLForConditionalGeneration",
            ]
            if self.hf_config.architectures[0] in mm_disabled_models:
                enable_multimodal = False
                logger.info(
                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
                )
            else:
                enable_multimodal = True

        if (
            is_draft_model
            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
        ):
            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"

        if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
            self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"

        if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
            self.hf_config.architectures[0] = "MiMoMTP"
        if (
            is_draft_model
            and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
        ):
            self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"

        # Check model type
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
        )
        self.is_multimodal = enable_multimodal and is_multimodal_model(
            self.hf_config.architectures
        )
        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
            self.hf_config.architectures
        )
        self.is_image_gen = enable_multimodal and is_image_gen_model(
            self.hf_config.architectures
        )
        self.is_audio_model = enable_multimodal and is_audio_model(
            self.hf_config.architectures
        )
        self.is_multimodal_chunked_prefill_supported = (
            enable_multimodal
            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
        )
        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

        # Derive context length
        derived_context_len = get_context_length(self.hf_text_config)
        if context_length is not None:
            if context_length > derived_context_len:
                if get_bool_env_var(
                    "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True"
                ):
                    logger.warning(
                        f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                        f"This may lead to incorrect model outputs or CUDA errors."
                    )
                    self.context_len = context_length
                else:
                    raise ValueError(
                        f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                        f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
                        f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                    )
            else:
                self.context_len = context_length
        else:
            self.context_len = derived_context_len

        # Unify the config keys for hf_text_config
        self.head_dim = getattr(
            self.hf_text_config,
            "head_dim",
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
        )

        # FIXME: temporary special-case handling for MLA architectures
        if (
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim

            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            if self.hf_config.rope_scaling:
                mscale_all_dim = self.hf_config.rope_scaling.get(
                    "mscale_all_dim", False
                )
                scaling_factor = self.hf_config.rope_scaling["factor"]
                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
                self.scaling = self.scaling * mscale * mscale

        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr(
            self.hf_text_config, "use_mla", True
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
        elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
            self.v_head_dim = self.hf_text_config.v_head_dim
            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
        else:
            if (
                "MistralModel" in self.hf_config.architectures
                or "MixtralForCausalLM" in self.hf_config.architectures
                or "MistralForCausalLM" in self.hf_config.architectures
            ):
                if getattr(self, "head_dim", None) is None:
                    self.head_dim = (
                        self.hf_config.hidden_size // self.hf_config.num_attention_heads
                    )
                    # In transformers==4.52.3, the head_dim is null in MistralConfig
                    if (
                        not hasattr(self.hf_text_config, "head_dim")
                        or self.hf_text_config.head_dim is None
                    ):
                        setattr(self.hf_text_config, "head_dim", self.head_dim)

            self.attention_arch = AttentionArch.MHA

        self.num_attention_heads = self.hf_text_config.num_attention_heads
        self.num_key_value_heads = getattr(
            self.hf_text_config, "num_key_value_heads", None
        )

        # for Dbrx and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
        self.num_nextn_predict_layers = getattr(
            self.hf_text_config, "num_nextn_predict_layers", None
        )
        self.vocab_size = self.hf_text_config.vocab_size

        # Verify quantization
        self._verify_quantization()

        # Verify dual-chunk attention config
        self._verify_dual_chunk_attention_config()

        # Cache attributes
        self.hf_eos_token_id = self.get_hf_eos_token_id()

        # multimodal
        self.image_token_id = getattr(
            self.hf_config, "image_token_id", None
        ) or getattr(self.hf_config, "image_token_index", None)

    @staticmethod
    def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
        return ModelConfig(
            model_path=model_path or server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
            model_impl=server_args.model_impl,
            **kwargs,
        )
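        # Illustrative usage sketch (not executed here; the model path and the
        # ServerArgs fields shown are hypothetical placeholders):
        #
        #   server_args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")
        #   model_config = ModelConfig.from_server_args(server_args)
        #   print(model_config.context_len, model_config.num_key_value_heads)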

    def get_total_num_attention_heads(self) -> int:
        return self.num_attention_heads

    def get_num_attention_heads(self, tensor_parallel_size) -> int:
        total_num_attention_heads = self.num_attention_heads
        return max(1, total_num_attention_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
            # For Step3
            "num_attention_groups",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, total_num_kv_heads // tensor_parallel_size)
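        # Worked example for the division above (illustrative numbers): a GQA
        # checkpoint with 8 total KV heads gives get_num_kv_heads(4) == 2, while
        # get_num_kv_heads(16) == 1 because KV heads are replicated so that every
        # tensor-parallel rank keeps at least one.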

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _parse_quant_hf_config(self):
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        if quant_cfg is None:
            # Check whether this is a ModelOpt model -- ModelOpt has no corresponding field
            # in the HF `config.json` but ships a standalone `hf_quant_config.json` in the root
            # directory, e.g. https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
            is_local = os.path.exists(self.model_path)
            modelopt_quant_config = {"quant_method": "modelopt"}
            if not is_local:
                from huggingface_hub import HfApi

                hf_api = HfApi()
                if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
                    quant_cfg = modelopt_quant_config
            elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
                quant_config_file = os.path.join(
                    self.model_path, "hf_quant_config.json"
                )
                with open(quant_config_file) as f:
                    quant_config_dict = json.load(f)
                json_quant_configs = quant_config_dict["quantization"]
                quant_algo = json_quant_configs.get("quant_algo", None)
                if quant_algo == "MIXED_PRECISION":
                    quant_cfg = {"quant_method": "w4afp8"}
                else:
                    quant_cfg = modelopt_quant_config
        return quant_cfg
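        # Sketch of the standalone `hf_quant_config.json` handled above (field values
        # are illustrative; only the keys read by this method are shown):
        #
        #   {"quantization": {"quant_algo": "FP8"}}              -> {"quant_method": "modelopt"}
        #   {"quantization": {"quant_algo": "MIXED_PRECISION"}}  -> {"quant_method": "w4afp8"}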

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _verify_quantization(self) -> None:
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq",
            "gptq",
            "fp8",
            "compressed_tensors",
            "compressed-tensors",
            "fbgemm_fp8",
            "w8a8_fp8",
            "petit_nvfp4",
            "quark",
            "mxfp4",
        ]
        optimized_quantization_methods = [
            "fp8",
            "marlin",
            "modelopt",
            "gptq_marlin_24",
            "gptq_marlin",
            "awq_marlin",
            "fbgemm_fp8",
            "compressed_tensors",
            "compressed-tensors",
            "experts_int8",
            "w8a8_int8",
            "w8a8_fp8",
            "moe_wna16",
            "qoq",
            "w4afp8",
            "petit_nvfp4",
        ]
        compatible_quantization_methods = {
            "modelopt_fp4": ["modelopt"],
            "petit_nvfp4": ["modelopt"],
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
        }
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get(
                "quant_method", "" if not self.quantization else self.quantization
            ).lower()

            # Detect which checkpoint format it is.
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization
                )
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                if (
                    self.quantization not in compatible_quantization_methods
                    or quant_method
                    not in compatible_quantization_methods[self.quantization]
                ):
                    raise ValueError(
                        "Quantization method specified in the model config "
                        f"({quant_method}) does not match the quantization "
                        f"method specified in the `quantization` argument "
                        f"({self.quantization})."
                    )

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}."
                )
            if is_hip() and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm."
                )
            if self.quantization not in optimized_quantization_methods:
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.",
                    self.quantization,
                )

    def _verify_dual_chunk_attention_config(self) -> None:
        if hasattr(self.hf_config, "dual_chunk_attention_config"):
            # Try loading the sparse attention config
            sparse_attn_config = get_sparse_attention_config(self.model_path)
            if not sparse_attn_config:
                return
            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
                sparse_attn_config
            )
            if (
                "sparse_attention_enabled"
                not in self.hf_config.dual_chunk_attention_config
            ):
                self.hf_config.dual_chunk_attention_config[
                    "sparse_attention_enabled"
                ] = True

    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids is not None:
            # It can be either an int or a list of ints.
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
        if eos_ids is None:
            eos_ids = set()
        if self.hf_generation_config:
            generation_eos_ids = getattr(
                self.hf_generation_config, "eos_token_id", None
            )
            if generation_eos_ids:
                generation_eos_ids = (
                    {generation_eos_ids}
                    if isinstance(generation_eos_ids, int)
                    else set(generation_eos_ids)
                )
                eos_ids = eos_ids | generation_eos_ids
        return eos_ids
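        # Example of the merge above (illustrative values): an eos_token_id of 2 in
        # config.json combined with eos_token_id=[2, 32000] in generation_config.json
        # yields {2, 32000}.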

    def maybe_pull_model_tokenizer_from_remote(self) -> None:
        """
        Pull the model config files to a temporary
        directory in case of remote.

        Args:
            model: The model name or path.

        """
        from sglang.srt.connector import create_remote_connector
        from sglang.srt.utils import is_remote_url

        if is_remote_url(self.model_path):
            logger.info("Pulling model configs from remote...")
            # BaseConnector implements __del__() to clean up the local dir.
            # Since the config files need to exist for the whole lifetime, we DO NOT
            # use a with statement here, to avoid closing the client.
            client = create_remote_connector(self.model_path)
            if is_remote_url(self.model_path):
                client.pull_files(allow_pattern=["*config.json"])
                self.model_weights = self.model_path
                self.model_path = client.get_local_dir()


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if isinstance(config_dtype, str):
        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type.startswith("gemma"):
                    if config.model_type == "gemma":
                        gemma_version = ""
                    else:
                        gemma_version = config.model_type[5]
                    logger.info(
                        f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16."
                    )
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
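# Resolution examples for _get_and_verify_dtype (illustrative):
#   dtype="auto" with config.torch_dtype=torch.bfloat16 -> torch.bfloat16
#   dtype="auto" with config.torch_dtype=torch.float32  -> torch.float16 (bfloat16 for gemma*)
#   dtype="half" with config.torch_dtype=torch.bfloat16 -> torch.float16, with a cast warning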


def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    # We have two ways to determine whether a model is a generative model.
    # 1. Check the model architecture
    # 2. check the `is_embedding` server args

    if (
        "LlamaEmbeddingModel" in model_architectures
        or "MistralModel" in model_architectures
        or "LlamaForSequenceClassification" in model_architectures
        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
        or "InternLM2ForRewardModel" in model_architectures
        or "Qwen2ForRewardModel" in model_architectures
        or "Qwen2ForSequenceClassification" in model_architectures
        or "CLIPModel" in model_architectures
        or "BertModel" in model_architectures
        or "Contriever" in model_architectures
        or "BertForSequenceClassification" in model_architectures
        or "XLMRobertaModel" in model_architectures
        or "XLMRobertaForSequenceClassification" in model_architectures
    ):
        return False
    else:
        return not is_embedding


multimodal_model_archs = [
    "CLIPModel",
    "DeepseekVL2ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3nForConditionalGeneration",
    "Glm4vForConditionalGeneration",
    "Glm4vMoeForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "LlavaLlamaForCausalLM",
    "Llama4ForConditionalGeneration",
    "LlavaMistralForCausalLM",
    "LlavaQwenForCausalLM",
    "LlavaForConditionalGeneration",
    "LlavaVidForCausalLM",
    "MiniCPMO",
    "MiniCPMV",
    "Mistral3ForConditionalGeneration",
    "MultiModalityCausalLM",
    "MllamaForConditionalGeneration",
    "Qwen2AudioForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "KimiVLForConditionalGeneration",
    "InternVLChatModel",
    "InternS1ForConditionalGeneration",
    "Phi4MMForCausalLM",
    "VILAForConditionalGeneration",
    "Step3VLForConditionalGeneration",
]


def is_multimodal_model(model_architectures: List[str]):
    if any(
        multi_model_arch in model_architectures
        for multi_model_arch in multimodal_model_archs
    ):
        return True
    else:
        return False


def is_multimodal_gen_model(model_architectures: List[str]):
    return False


def is_image_gen_model(model_architectures: List[str]):
    return False


def is_audio_model(model_architectures: List[str]):
    return False


def is_encoder_decoder_model(model_architectures: List[str]):
    return "MllamaForConditionalGeneration" in model_architectures


def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
    """Check if chunked prefill is supported for a MultiModal model."""
    unsupported = [
        "Grok1VForCausalLM",
        "Grok1AForCausalLM",
        "LlavaLlamaForCausalLM",
        "MllamaForConditionalGeneration",
        "CLIPModel",
    ]
    if any(multi_model_arch in unsupported for multi_model_arch in model_architectures):
        return False
    else:
        return True


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
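# Worked example (illustrative): with a YaRN rope_scaling factor of 4.0 and
# mscale_all_dim=1.0, yarn_get_mscale(4.0, 1.0) = 0.1 * 1.0 * ln(4.0) + 1.0 ≈ 1.139,
# and ModelConfig.__init__ above multiplies the MLA softmax scaling by mscale * mscale.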


def is_hybrid_model(
    model_architectures: List[str],
    hybrid_kvcache_ratio: Optional[float],
    context_length: Optional[int],
    attention_chunk_size: Optional[int],
):
    if hybrid_kvcache_ratio is None:
        return None
    elif (
        hybrid_kvcache_ratio > 0
        and model_architectures[0] == "Llama4ForConditionalGeneration"
        and context_length > attention_chunk_size
    ):
        return hybrid_kvcache_ratio
    else:
        return None


def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
    if "Llama4ForConditionalGeneration" in model_architectures:
        swa_attention_layer_ids = [
            i for i in range(num_hidden_layers) if (i + 1) % 4 != 0
        ]
        full_attention_layer_ids = [
            i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
        ]
    else:
        swa_attention_layer_ids = None
        full_attention_layer_ids = None
    return swa_attention_layer_ids, full_attention_layer_ids
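# Example of the split above (illustrative): a Llama4ForConditionalGeneration config
# with num_hidden_layers=8 yields swa_attention_layer_ids=[0, 1, 2, 4, 5, 6] and
# full_attention_layer_ids=[3, 7]; every 4th layer keeps full attention.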