model_config.py
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json
import logging
import math
from enum import IntEnum, auto
from typing import List, Optional, Set, Union

import torch
from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import get_config, get_context_length
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.utils import get_bool_env_var, is_hip

logger = logging.getLogger(__name__)


class AttentionArch(IntEnum):
    MLA = auto()
    MHA = auto()


class ModelConfig:
    def __init__(
        self,
        model_path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: Optional[str] = None,
        is_embedding: Optional[bool] = None,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        override_config_file: Optional[str] = None,
    ) -> None:
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization

        # Parse args
        self.maybe_pull_model_tokenizer_from_remote()
        self.model_override_args = json.loads(model_override_args or "{}")
        kwargs = {}
        if override_config_file and override_config_file.strip():
            kwargs["_configuration_file"] = override_config_file.strip()

        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
            **kwargs,
        )
        self.hf_text_config = get_hf_text_config(self.hf_config)

        # Check model type
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
        )
        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
        self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
        self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
        self.is_audio_model = is_audio_model(self.hf_config.architectures)
        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

        # Derive context length
        derived_context_len = get_context_length(self.hf_text_config)
        if context_length is not None:
            if context_length > derived_context_len:
                if get_bool_env_var(
                    "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True"
                ):
                    logger.warning(
                        f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                        f"This may lead to incorrect model outputs or CUDA errors."
                    )
                    self.context_len = context_length
                else:
                    raise ValueError(
                        f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                        f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
                        f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                    )
            else:
                self.context_len = context_length
        else:
            self.context_len = derived_context_len

        # Unify the config keys for hf_text_config
        self.head_dim = getattr(
            self.hf_text_config,
            "head_dim",
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
        )

        # FIXME: temporary special case for MLA architectures
        if (
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim

            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            if self.hf_config.rope_scaling:
                mscale_all_dim = self.hf_config.rope_scaling.get(
                    "mscale_all_dim", False
                )
                scaling_factor = self.hf_config.rope_scaling["factor"]
                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
                self.scaling = self.scaling * mscale * mscale

        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
        else:
            self.attention_arch = AttentionArch.MHA

        self.num_attention_heads = self.hf_text_config.num_attention_heads
        self.num_key_value_heads = getattr(
            self.hf_text_config, "num_key_value_heads", None
        )

        # for Dbrx and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
        self.vocab_size = self.hf_text_config.vocab_size

        # Verify quantization
        self._verify_quantization()

        # Cache attributes
        self.hf_eos_token_id = self.get_hf_eos_token_id()
        self.image_token_id = getattr(self.hf_config, "image_token_id", None)

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, total_num_kv_heads // tensor_parallel_size)
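    # A worked example of the division above, with hypothetical numbers: a
    # checkpoint whose config resolves to 8 total KV heads.
    #
    #   get_num_kv_heads(tensor_parallel_size=4)   # max(1, 8 // 4)  -> 2
    #   get_num_kv_heads(tensor_parallel_size=16)  # max(1, 8 // 16) -> 1, i.e.
    #                                              # each KV head is replicated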

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _parse_quant_hf_config(self):
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        return quant_cfg

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _verify_quantization(self) -> None:
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq",
            "gptq",
            "fp8",
            "compressed_tensors",
            "compressed-tensors",
            "fbgemm_fp8",
            "w8a8_fp8",
        ]
        optimized_quantization_methods = [
            "fp8",
            "marlin",
            "modelopt",
            "gptq_marlin_24",
            "gptq_marlin",
            "awq_marlin",
            "fbgemm_fp8",
            "compressed_tensors",
            "compressed-tensors",
            "experts_int8",
            "w8a8_int8",
            "w8a8_fp8",
        ]
        compatible_quantization_methods = {
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
        }
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()

            # Detect which checkpoint format it is
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization
                )
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                if (
                    self.quantization not in compatible_quantization_methods
                    or quant_method
                    not in compatible_quantization_methods[self.quantization]
                ):
                    raise ValueError(
                        "Quantization method specified in the model config "
                        f"({quant_method}) does not match the quantization "
                        f"method specified in the `quantization` argument "
                        f"({self.quantization})."
                    )

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}."
                )
            if is_hip() and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm."
                )
            if self.quantization not in optimized_quantization_methods:
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.",
                    self.quantization,
                )
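    # A sketch of how the checks above play out (hypothetical model path and
    # checkpoint): a compressed-tensors checkpoint served with
    # `quantization="w8a8_int8"` is accepted, because "w8a8_int8" lists
    # "compressed-tensors" in compatible_quantization_methods; any other
    # mismatch between the requested method and the checkpoint's quant_method
    # raises a ValueError.
    #
    #   ModelConfig(model_path, quantization="w8a8_int8")  # compressed-tensors ckpt -> OK
    #   ModelConfig(model_path, quantization="awq")        # compressed-tensors ckpt -> ValueError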

    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
        return eos_ids

    def maybe_pull_model_tokenizer_from_remote(self) -> None:
        """
        Pull the model config files to a temporary
        directory in case of remote.

        Args:
            model: The model name or path.

        """
        from sglang.srt.connector import create_remote_connector
        from sglang.srt.utils import is_remote_url

        if is_remote_url(self.model_path):
            logger.info("Pulling model configs from remote...")
            # BaseConnector implements __del__() to clean up the local dir.
            # Since the config files need to exist for the whole lifetime, we do
            # NOT use a `with` statement, to avoid closing the client too early.
            client = create_remote_connector(self.model_path)
            if is_remote_url(self.model_path):
                client.pull_files(allow_pattern=["*config.json"])
                self.model_weights = self.model_path
                self.model_path = client.get_local_dir()
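    # A minimal usage sketch (hypothetical model path; building the config
    # reads/downloads the checkpoint's HF config):
    #
    #   config = ModelConfig("meta-llama/Llama-3.1-8B-Instruct")
    #   config.context_len          # derived from the HF config unless overridden
    #   config.dtype                # torch dtype resolved by _get_and_verify_dtype
    #   config.get_num_kv_heads(8)  # KV heads per GPU at tensor_parallel_size=8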


def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to the LLM for multimodal models.
    No-op for pure text models.
    """
    class_name = config.architectures[0]
    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
        # We support non-hf version of llava models, so we do not want to
        # read the wrong values from the unused default text_config.
        # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
        # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
        setattr(config, "torch_dtype", torch.float16)
        return config

    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    if hasattr(config, "language_config"):
        return config.language_config
    else:
        return config


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type.startswith("gemma"):
                    if config.model_type == "gemma":
                        gemma_version = ""
                    else:
                        gemma_version = config.model_type[5]
                    logger.info(
                        f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16."
                    )
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
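# A sketch of the "auto" policy above with hypothetical HF configs (cfg_bf16,
# cfg_fp32, cfg_gemma2_fp32 stand for configs whose torch_dtype is as named):
#
#   _get_and_verify_dtype(cfg_bf16, "auto")        # -> torch.bfloat16 (kept)
#   _get_and_verify_dtype(cfg_fp32, "auto")        # -> torch.float16 (downcast)
#   _get_and_verify_dtype(cfg_gemma2_fp32, "auto") # -> torch.bfloat16 (Gemma default)
#   _get_and_verify_dtype(cfg_bf16, "float16")     # -> torch.float16 (explicit, warns)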


def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    # We have two ways to determine whether a model is a generative model.
    # 1. Check the model architecture
    # 2. Check the `is_embedding` server argument

    if (
        "LlamaEmbeddingModel" in model_architectures
        or "MistralModel" in model_architectures
        or "LlamaForSequenceClassification" in model_architectures
        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
        or "InternLM2ForRewardModel" in model_architectures
        or "Qwen2ForRewardModel" in model_architectures
    ):
        return False
    else:
        return not is_embedding


multimodal_model_archs = [
    "DeepseekVL2ForCausalLM",
    "LlavaLlamaForCausalLM",
    "LlavaQwenForCausalLM",
    "LlavaMistralForCausalLM",
    "LlavaVidForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "MllamaForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "MiniCPMV",
    "MultiModalityCausalLM",
]


def is_multimodal_model(model_architectures: List[str]):
    return any(
        multi_model_arch in model_architectures
        for multi_model_arch in multimodal_model_archs
    )


def is_multimodal_gen_model(model_architectures: List[str]):
    return False


def is_image_gen_model(model_architectures: List[str]):
    return False


def is_audio_model(model_architectures: List[str]):
    return False


def is_encoder_decoder_model(model_architectures: List[str]):
    return "MllamaForConditionalGeneration" in model_architectures


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
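

# A minimal, runnable check of the YaRN mscale math above. The factor of 40 is
# a hypothetical rope_scaling["factor"]; real values come from the checkpoint.
if __name__ == "__main__":
    factor = 40.0
    mscale = yarn_get_mscale(factor, mscale=1.0)
    print(f"mscale for factor {factor}: {mscale:.4f}")  # ~1.3689
    # __init__ multiplies the base 1 / sqrt(qk head dim) scaling by mscale**2
    print(f"attention scaling multiplier: {mscale * mscale:.4f}")  # ~1.8739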