# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json
import logging
import math
from enum import IntEnum, auto
from typing import List, Optional, Set, Union

import torch
from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import get_config, get_context_length
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.utils import get_bool_env_var, is_hip

logger = logging.getLogger(__name__)


class AttentionArch(IntEnum):
    MLA = auto()
    MHA = auto()


class ModelConfig:
    def __init__(
        self,
        model_path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: Optional[str] = None,
        is_embedding: Optional[bool] = None,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        override_config_file: Optional[str] = None,
    ) -> None:
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization

        # Parse args (fall back to an empty override dict when none is given)
        self.model_override_args = json.loads(model_override_args or "{}")
        kwargs = {}
        if override_config_file and override_config_file.strip():
            kwargs["_configuration_file"] = override_config_file.strip()

        self.hf_config = get_config(
            model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
            **kwargs,
        )
        self.hf_text_config = get_hf_text_config(self.hf_config)

        # Check model type
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
        )
        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
        self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
        self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
        self.is_audio_model = is_audio_model(self.hf_config.architectures)
        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

        # Derive context length
        derived_context_len = get_context_length(self.hf_text_config)
        if context_length is not None:
            if context_length > derived_context_len:
                if get_bool_env_var(
                    "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True"
                ):
                    logger.warning(
                        f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                        f"This may lead to incorrect model outputs or CUDA errors."
                    )
                    self.context_len = context_length
                else:
                    raise ValueError(
                        f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                        f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
                        f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                    )
            else:
                self.context_len = context_length
        else:
            self.context_len = derived_context_len

        # Unify the config keys for hf_text_config
        self.head_dim = getattr(
            self.hf_text_config,
            "head_dim",
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
        )
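        # Illustrative example (values assumed, not from the original source):
        # a text config with hidden_size=4096 and num_attention_heads=32 falls
        # back to head_dim = 4096 // 32 = 128 when it does not define
        # `head_dim` explicitly.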

        # FIXME: temporary special judge for MLA architecture
        if (
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim

            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            if self.hf_config.rope_scaling:
                mscale_all_dim = self.hf_config.rope_scaling.get(
                    "mscale_all_dim", False
                )
                scaling_factor = self.hf_config.rope_scaling["factor"]
                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
                self.scaling = self.scaling * mscale * mscale

        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        else:
            self.attention_arch = AttentionArch.MHA

        self.num_attention_heads = self.hf_text_config.num_attention_heads
        self.num_key_value_heads = getattr(
            self.hf_text_config, "num_key_value_heads", None
        )

        # For DBRX and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
        self.vocab_size = self.hf_text_config.vocab_size

        # Verify quantization
        self._verify_quantization()

        # Cache attributes
        self.hf_eos_token_id = self.get_hf_eos_token_id()
        self.image_token_id = getattr(self.hf_config, "image_token_id", None)

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
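        # Illustrative examples (assumed values, not from the original source):
        # 8 total KV heads with tensor_parallel_size=4 -> 2 KV heads per GPU;
        # 2 total KV heads with tensor_parallel_size=8 -> 1 KV head per GPU,
        # i.e. the heads are replicated rather than split further.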
        return max(1, total_num_kv_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _parse_quant_hf_config(self):
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        return quant_cfg

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _verify_quantization(self) -> None:
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq",
            "gptq",
            "fp8",
            "compressed_tensors",
            "compressed-tensors",
            "fbgemm_fp8",
        ]
        optimized_quantization_methods = [
            "fp8",
            "marlin",
            "modelopt",
            "gptq_marlin_24",
            "gptq_marlin",
            "awq_marlin",
            "fbgemm_fp8",
            "compressed_tensors",
            "compressed-tensors",
            "experts_int8",
            "w8a8_int8",
            "w8a8_fp8",
        ]
        compatible_quantization_methods = {
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
        }
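        # Illustrative reading of the mapping above (not from the original
        # source): a checkpoint whose HF config reports "compressed-tensors"
        # may be served with `quantization="w8a8_int8"` or `"w8a8_fp8"`; any
        # other mismatch between the two raises a ValueError further below.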
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()

            # Detect which checkpoint it is
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization
                )
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                if (
                    self.quantization not in compatible_quantization_methods
                    or quant_method
                    not in compatible_quantization_methods[self.quantization]
                ):
                    raise ValueError(
                        "Quantization method specified in the model config "
                        f"({quant_method}) does not match the quantization "
                        f"method specified in the `quantization` argument "
                        f"({self.quantization})."
                    )

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}."
                )
            if is_hip() and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm."
                )
            if self.quantization not in optimized_quantization_methods:
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.",
                    self.quantization,
                )

    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
        return eos_ids

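# Illustrative usage sketch (hypothetical model path and arguments, not part
# of the original module):
#
#   config = ModelConfig(
#       "meta-llama/Llama-3.1-8B-Instruct",
#       context_length=8192,
#       model_override_args="{}",
#   )
#   kv_heads_per_gpu = config.get_num_kv_heads(tensor_parallel_size=4)
#   print(config.dtype, config.context_len, kv_heads_per_gpu)
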

def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to llm for multi modal models.
    No op for pure text models.
    """
    class_name = config.architectures[0]
    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
        # We support non-hf version of llava models, so we do not want to
        # read the wrong values from the unused default text_config.
        # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
        # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
        setattr(config, "torch_dtype", torch.float16)
        return config

    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    else:
        return config


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                if config.model_type == "gemma2":
                    logger.info(
                        "For Gemma 2, we downcast float32 to bfloat16 instead "
                        "of float16 by default. Please specify `dtype` if you "
                        "want to use float16."
                    )
                    torch_dtype = torch.bfloat16
                else:
                    # Following the common practice, we use float16 for float32
                    # models.
                    torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype


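# Illustrative resolution of dtype="auto" (assumed checkpoints, not from the
# original source): a model saved in float32 resolves to torch.float16
# (torch.bfloat16 for gemma2), while a bfloat16 checkpoint keeps
# torch.bfloat16; an explicit string such as "float16" bypasses the auto rule
# and is mapped through _STR_DTYPE_TO_TORCH_DTYPE.
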
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    # We have two ways to determine whether a model is a generative model.
    # 1. Check the model architecture
    # 2. Check the `is_embedding` server argument

    if (
        "LlamaEmbeddingModel" in model_architectures
        or "MistralModel" in model_architectures
        or "LlamaForSequenceClassification" in model_architectures
        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
        or "InternLM2ForRewardModel" in model_architectures
        or "Qwen2ForRewardModel" in model_architectures
    ):
        return False
    else:
        return not is_embedding


def is_multimodal_model(model_architectures: List[str]):
    if (
        "LlavaLlamaForCausalLM" in model_architectures
        or "LlavaQwenForCausalLM" in model_architectures
        or "LlavaMistralForCausalLM" in model_architectures
        or "LlavaVidForCausalLM" in model_architectures
        or "Grok1VForCausalLM" in model_architectures
        or "Grok1AForCausalLM" in model_architectures
        or "MllamaForConditionalGeneration" in model_architectures
        or "Qwen2VLForConditionalGeneration" in model_architectures
        or "Qwen2_5_VLForConditionalGeneration" in model_architectures
        or "MiniCPMV" in model_architectures
    ):
        return True
    else:
        return False


def is_multimodal_gen_model(model_architectures: List[str]):
    return False


def is_image_gen_model(model_architectures: List[str]):
    return False


def is_audio_model(model_architectures: List[str]):
    return False


def is_encoder_decoder_model(model_architectures: List[str]):
    return "MllamaForConditionalGeneration" in model_architectures


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
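

# Illustrative numbers (not part of the original module): with a YaRN rope
# scaling factor of 40 and mscale_all_dim=1.0, yarn_get_mscale(40, 1.0)
# returns 0.1 * 1.0 * log(40) + 1.0, roughly 1.369; ModelConfig multiplies its
# base MLA attention scaling by this value squared (see the rope scaling
# handling in __init__ above).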