from typing import Optional

from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import get_config, get_context_length


class ModelConfig:
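    """Model configuration normalized from a Hugging Face config so that the
    rest of the runtime can read a consistent set of keys (head_dim,
    num_key_value_heads, context_len, ...)."""
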
    def __init__(
        self,
        path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_overide_args: Optional[dict] = None,
    ) -> None:
        self.path = path
        self.trust_remote_code = trust_remote_code
        self.revision = revision
        self.model_overide_args = model_overide_args
        self.hf_config = get_config(
            self.path,
            trust_remote_code,
            revision,
            model_overide_args=model_overide_args,
        )
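        # get_config loads the Hugging Face config for `path`; any entries in
        # model_overide_args are applied on top of the loaded config.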
        self.hf_text_config = get_hf_text_config(self.hf_config)
        if context_length is not None:
            self.context_len = context_length
        else:
            self.context_len = get_context_length(self.hf_config)
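        # Without an explicit override, get_context_length derives the window
        # from common HF config fields (e.g. max_position_embeddings).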

        # Unify the config keys for hf_config
        self.head_dim = getattr(
            self.hf_config,
            "head_dim",
            self.hf_config.hidden_size // self.hf_config.num_attention_heads,
        )
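        # For example, a Llama-2-7B-style config with hidden_size=4096 and
        # num_attention_heads=32 yields head_dim = 4096 // 32 = 128.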

        # FIXME: temporary special-case handling for the DeepSeek-V2 MLA architecture
        if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 256

        self.num_attention_heads = self.hf_config.num_attention_heads
        self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)

        # For DBRX and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
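        # If the config does not specify KV heads, assume standard multi-head
        # attention (one KV head per attention head).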
        self.hidden_size = self.hf_config.hidden_size
        self.num_hidden_layers = self.hf_config.num_hidden_layers
        self.vocab_size = self.hf_config.vocab_size

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
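            # (e.g. Falcon-7B sets multi_query=True, so every query head shares
            # a single KV head.)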
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads
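        # For example, a grouped-query model such as Llama-2-70B publishes
        # num_key_value_heads = 8, so the loop above returns 8 even though the
        # model has 64 attention heads.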

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
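        # For example, with 8 KV heads and tensor_parallel_size=4 each GPU gets
        # 8 // 4 = 2 KV heads; with tensor_parallel_size=16 the division floors
        # to 0 and max(1, ...) keeps one (replicated) KV head per GPU.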
        return max(1, total_num_kv_heads // tensor_parallel_size)


def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to llm for multi modal models.
    No-op for pure text models.
    """
    class_name = config.architectures[0]
    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
        # We support non-HF versions of LLaVA models, so we do not want to
        # read the wrong values from the unused default text_config.
        return config

    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if the transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    else:
        return config
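

# A minimal usage sketch (illustrative only; the checkpoint name below is a
# placeholder for any standard Hugging Face causal-LM config):
#
#   config = ModelConfig("meta-llama/Llama-2-7b-hf", context_length=4096)
#   config.context_len                  # 4096 (explicit override wins)
#   config.head_dim                     # 128
#   config.get_num_kv_heads(2)          # 32 KV heads // TP=2 -> 16 per GPU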