# coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StableLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
22
from typing import Iterable, List, Optional, Tuple
Hyunsung Lee's avatar
Hyunsung Lee committed
23

24
25
import torch
from torch import nn
Hyunsung Lee's avatar
Hyunsung Lee committed
26
27
from transformers import PretrainedConfig

28
from vllm.attention import Attention, AttentionMetadata
29
from vllm.distributed import get_tensor_model_parallel_world_size
30
from vllm.model_executor.layers.activation import SiluAndMul
31
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
32
33
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
36
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
37
from vllm.model_executor.layers.rotary_embedding import get_rope
38
39
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
40
    ParallelLMHead, VocabParallelEmbedding)
41
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
42
43
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
Hyunsung Lee's avatar
Hyunsung Lee committed
44

45
46
47
48
49

class StablelmMLP(nn.Module):
    """Feed-forward (MLP) sub-block of a StableLM decoder layer.

    The gate and up projections are fused into one column-parallel linear
    layer. ``SiluAndMul`` is applied to its output, and the result is
    projected back to ``hidden_size`` by a row-parallel linear layer.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # gate_proj and up_proj are stored as shards 0 and 1 of this single
        # fused weight (see the stacked mapping in the weight loader).
        self.gate_up_proj = MergedColumnParallelLinear(
            config.hidden_size, [config.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        # Quantize down_proj the same way as gate_up_proj so that a quantized
        # checkpoint does not leave this projection in full precision.
        self.down_proj = RowParallelLinear(config.intermediate_size,
                                           config.hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, SiluAndMul activation, down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class StablelmAttention(nn.Module):
    """Self-attention for StableLM with partial rotary embeddings.

    Query heads are divided evenly across tensor-parallel ranks. Key/value
    heads are partitioned across ranks when there are at least as many of
    them as ranks, and replicated otherwise.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        # Per-rank number of query heads.
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_key_value_heads = config.num_key_value_heads
        if self.total_num_key_value_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_key_value_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_key_value_heads == 0
        # Per-rank number of KV heads (at least one when replicated).
        self.num_key_value_heads = max(
            1, self.total_num_key_value_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        # Fraction of each head's channels given to the rotary embedding.
        # Read `rope_pct`, else `partial_rotary_factor`, else the full head.
        rope_pct = getattr(config, "rope_pct",
                           getattr(config, "partial_rotary_factor", 1))
        self.rotary_ndims = int(self.head_dim * rope_pct)
        self.scaling = self.head_dim**-0.5
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim
        self.qkv_bias = getattr(config, "use_qkv_bias", False)
        # This equality holds exactly when hidden_size is a multiple of the
        # total head count AND the head count splits evenly over TP ranks, so
        # the message reports the total head count and the TP size.
        if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads, and num_heads "
                f"must be divisible by the tensor parallel size (got "
                f"`hidden_size`: {self.hidden_size}, `num_heads`: "
                f"{self.total_num_heads} and `tp_size`: {tp_size}).")

        self.qkv_proj = QKVParallelLinear(self.hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_key_value_heads,
                                          self.qkv_bias,
                                          quant_config=quant_config)
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_ndims,
            max_position=self.config.max_position_embeddings,
            base=self.config.rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_key_value_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings to Q and K, attend using
        the KV cache, and project the result back to ``hidden_size``."""
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection output is laid out as [q | k | v] on the last
        # dimension, sized for this rank's heads.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class StablelmDecoderLayer(nn.Module):
    """One StableLM transformer block.

    LayerNorm -> self-attention -> residual add, then LayerNorm -> MLP ->
    residual add.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # Pass quant_config to attention as well as to the MLP; otherwise the
        # QKV and output projections would be built unquantized.
        self.self_attn = StablelmAttention(config, quant_config)
        self.mlp = StablelmMLP(config, quant_config)
        # Read the epsilon from `norm_eps`, else `layer_norm_eps`, else 1e-05.
        norm_eps = getattr(config, "norm_eps",
                           getattr(config, "layer_norm_eps", 1e-05))
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block.

        Returns:
            A tuple ``(hidden_states, residual)``. ``hidden_states`` is the
            block output. ``residual`` is the post-attention sum that was fed
            into the MLP sub-block.
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, residual


class StableLMEpochModel(nn.Module):
    """StableLM decoder stack.

    Consists of a token embedding, ``config.num_hidden_layers`` decoder
    layers, and a final LayerNorm.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            StablelmDecoderLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        # Read the epsilon from `norm_eps`, else `layer_norm_eps`, else 1e-05.
        norm_eps = getattr(config, "norm_eps",
                           getattr(config, "layer_norm_eps", 1e-05))
        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer, and apply the final
        LayerNorm.

        Args:
            input_ids: Token ids to embed.
            positions: Positions forwarded to each layer's rotary embedding.
            kv_caches: KV cache tensors, indexed by decoder-layer position.
            attn_metadata: Attention metadata shared by all layers.

        Returns:
            The normalized hidden states after the last layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        for layer_idx, layer in enumerate(self.layers):
            # Each layer also returns its residual stream, which is not
            # needed here.
            hidden_states, _ = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class StablelmForCausalLM(nn.Module):
    """StableLM decoder stack plus a language-modeling head.

    ``forward`` produces hidden states, ``compute_logits`` turns them into
    logits through ``lm_head``, and ``sample`` selects next tokens.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = StableLMEpochModel(config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder stack's final hidden states."""
        return self.model(input_ids, positions, kv_caches, attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits using lm_head's weight."""
        return self.logits_processor(self.lm_head.weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate next-token selection to the sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed, with a
        shard id, into the fused ``qkv_proj`` / ``gate_up_proj`` parameters.
        Rotary-embedding cache tensors are ignored. Bias tensors that have no
        matching parameter (extra GPTQ biases) are skipped. Any other unknown
        name raises ``KeyError``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        # Models trained using ColossalAI may also ship the cos/sin caches.
        ignored_markers = ("rotary_emb.inv_freq", "rotary_emb.cos_cached",
                           "rotary_emb.sin_cached")

        for ckpt_name, tensor in weights:
            if any(marker in ckpt_name for marker in ignored_markers):
                continue

            target = ckpt_name
            handled_as_shard = False
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in target:
                    continue
                # NOTE: the renamed target is carried into later iterations
                # (and the fallback below) when a bias is skipped here.
                target = target.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if target.endswith(".bias") and target not in params_dict:
                    continue
                param = params_dict[target]
                param.weight_loader(param, tensor, shard_id)
                handled_as_shard = True
                break

            if handled_as_shard:
                continue

            # Skip loading extra bias for GPTQ models.
            if target.endswith(".bias") and target not in params_dict:
                continue
            param = params_dict[target]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor)