stablelm.py 12.7 KB
Newer Older
Hyunsung Lee's avatar
Hyunsung Lee committed
1
# coding=utf-8
2
3
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
# All rights reserved.
Hyunsung Lee's avatar
Hyunsung Lee committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
20
21
"""Inference-only StableLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
22
from typing import Iterable, List, Optional, Tuple
Hyunsung Lee's avatar
Hyunsung Lee committed
23

24
25
import torch
from torch import nn
Hyunsung Lee's avatar
Hyunsung Lee committed
26
27
from transformers import PretrainedConfig

28
from vllm.attention import Attention, AttentionMetadata
29
from vllm.config import CacheConfig
30
from vllm.distributed import get_tensor_model_parallel_world_size
31
from vllm.model_executor.layers.activation import SiluAndMul
32
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
33
34
                                               QKVParallelLinear,
                                               RowParallelLinear)
35
from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
37
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
38
from vllm.model_executor.layers.rotary_embedding import get_rope
39
40
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
41
    ParallelLMHead, VocabParallelEmbedding)
42
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
43
from vllm.model_executor.sampling_metadata import SamplingMetadata
44
from vllm.sequence import IntermediateTensors, SamplerOutput
Hyunsung Lee's avatar
Hyunsung Lee committed
45

46
47
48
49
50

class StablelmMLP(nn.Module):
    """Gated feed-forward block used inside a StableLM decoder layer.

    The gate and up projections are fused into a single
    ``MergedColumnParallelLinear``; its output is passed through
    ``SiluAndMul`` and then through the ``RowParallelLinear`` down
    projection.

    Args:
        config: Model configuration; ``hidden_size`` and
            ``intermediate_size`` are read from it.
        quant_config: Optional quantization settings, forwarded to both
            linear layers.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_up_proj = MergedColumnParallelLinear(
            config.hidden_size, [config.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        # down_proj must be built with the same quant_config as gate_up_proj
        # (and as o_proj in the attention block); otherwise the MLP would mix
        # a quantized and an unquantized linear layer.
        self.down_proj = RowParallelLinear(config.intermediate_size,
                                           config.hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection to ``x``.

        The linear layers return a tuple; only the first element (the
        projected tensor) is used.
        """
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class StablelmAttention(nn.Module):
    """Self-attention block of a StableLM decoder layer.

    Projects the hidden states with a fused QKV linear layer, applies rotary
    position embeddings to the query and key, runs ``Attention`` against the
    KV cache and projects the result with ``o_proj``.

    Args:
        config: Model configuration. Reads ``hidden_size``,
            ``num_attention_heads``, ``num_key_value_heads``,
            ``max_position_embeddings``, ``rope_theta`` and, when present,
            ``rope_pct`` / ``partial_rotary_factor`` and ``use_qkv_bias``.
        cache_config: Optional cache settings forwarded to ``Attention``.
        quant_config: Optional quantization settings forwarded to the linear
            layers and to ``Attention``.

    Raises:
        ValueError: If ``head_dim * num_heads * tp_size`` does not equal
            ``hidden_size``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        # Per-rank number of query heads.
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_key_value_heads = config.num_key_value_heads
        if self.total_num_key_value_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_key_value_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_key_value_heads == 0
        self.num_key_value_heads = max(
            1, self.total_num_key_value_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        # Fraction of each head that receives rotary embeddings. Configs may
        # call it `rope_pct` or `partial_rotary_factor`; default is the
        # whole head.
        rope_pct = getattr(config, "rope_pct",
                           getattr(config, "partial_rotary_factor", 1))
        self.rotary_ndims = int(self.head_dim * rope_pct)
        self.scaling = self.head_dim**-0.5
        # Per-rank widths used to split the fused QKV output in forward().
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim
        self.qkv_bias = getattr(config, "use_qkv_bias", False)
        if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
            # Report the configured (total) head count: self.num_heads is the
            # per-rank value and would not match the user's configuration
            # when tp_size > 1.
            raise ValueError("hidden_size must be divisible by num_heads "
                             f"(got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {self.total_num_heads}).")

        self.qkv_proj = QKVParallelLinear(self.hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_key_value_heads,
                                          self.qkv_bias,
                                          quant_config=quant_config)
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_ndims,
            max_position=self.config.max_position_embeddings,
            base=self.config.rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_key_value_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projection.

        Args:
            positions: Positions passed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Cache tensor handed to ``Attention`` for this layer.
            attn_metadata: Metadata handed to ``Attention``.

        Returns:
            The first element of the ``o_proj`` output.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into q, k and v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class StablelmDecoderLayer(nn.Module):
    """One StableLM decoder layer: attention and MLP, each pre-normalized.

    Both sub-blocks are preceded by an ``nn.LayerNorm`` and wrapped in a
    residual addition.

    Args:
        config: Model configuration. ``hidden_size`` is read directly; the
            LayerNorm epsilon comes from ``norm_eps``, falling back to
            ``layer_norm_eps`` and then to ``1e-05``.
        cache_config: Optional cache settings forwarded to the attention
            block.
        quant_config: Optional quantization settings forwarded to the
            attention and MLP blocks.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.self_attn = StablelmAttention(config, cache_config, quant_config)
        self.mlp = StablelmMLP(config, quant_config)

        fallback_eps = getattr(config, "layer_norm_eps", 1e-05)
        norm_eps = getattr(config, "norm_eps", fallback_eps)
        width = config.hidden_size
        self.input_layernorm = nn.LayerNorm(width, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the attention and MLP sub-blocks with residual additions.

        Returns:
            A pair ``(output, residual)`` where ``residual`` is the tensor
            that was fed into the MLP sub-block (i.e. the result of the
            attention residual addition) and ``output`` is that tensor plus
            the MLP result.
        """
        # Attention sub-block.
        normed_input = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed_input,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        after_attn = hidden_states + attn_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        layer_out = after_attn + mlp_out

        return layer_out, after_attn


class StableLMEpochModel(nn.Module):
    """StableLM backbone: token embedding, decoder stack and final norm.

    Args:
        config: Model configuration. Reads ``vocab_size``, ``hidden_size``
            and ``num_hidden_layers``; the final LayerNorm epsilon comes
            from ``norm_eps``, falling back to ``layer_norm_eps`` and then
            to ``1e-05``.
        cache_config: Optional cache settings forwarded to every decoder
            layer.
        quant_config: Optional quantization settings forwarded to every
            decoder layer.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        decoder_stack = []
        for _ in range(config.num_hidden_layers):
            decoder_stack.append(
                StablelmDecoderLayer(config, cache_config, quant_config))
        self.layers = nn.ModuleList(decoder_stack)

        fallback_eps = getattr(config, "layer_norm_eps", 1e-05)
        norm_eps = getattr(config, "norm_eps", fallback_eps)
        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer, apply the final norm.

        ``kv_caches`` is indexed by layer position, so it needs one entry
        per decoder layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        for layer_idx, layer in enumerate(self.layers):
            # Each layer also returns its residual; it is not used here.
            hidden_states, _ = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
            )
        return self.norm(hidden_states)


class StablelmForCausalLM(nn.Module):
    """StableLM backbone with an LM head, logits processor and sampler.

    Args:
        config: Model configuration; ``vocab_size`` and ``hidden_size`` are
            read here, and the whole object is passed to the backbone.
        cache_config: Optional cache settings forwarded to the backbone.
        quant_config: Optional quantization settings forwarded to the
            backbone and the LM head.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = StableLMEpochModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the backbone's hidden states for ``input_ids``.

        ``intermediate_tensors`` is accepted but not used by this method.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor with the LM head over ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Checkpoint tensors for the separate q/k/v and gate/up projections
        are routed into the fused ``qkv_proj`` / ``gate_up_proj`` parameters
        with the matching shard id. Rotary-embedding buffers and ``.bias``
        entries with no corresponding parameter are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Name fragments of checkpoint entries that are never loaded.
        # Models trained using ColossalAI may include the cos/sin caches in
        # the checkpoint.
        ignored_fragments = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if any(fragment in name for fragment in ignored_fragments):
                continue
            for (fused_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name in name:
                    name = name.replace(shard_name, fused_name)
                    # Skip loading extra bias for GPTQ models. In that case
                    # the scan continues with the rewritten name and ends up
                    # in the else-branch below, which skips it as well.
                    if not name.endswith(".bias") or name in params_dict:
                        target = params_dict[name]
                        target.weight_loader(target, loaded_weight, shard_id)
                        break
            else:
                # Not a fused-parameter shard (or a skipped bias).
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                target = params_dict[name]
                load_fn = getattr(target, "weight_loader",
                                  default_weight_loader)
                load_fn(target, loaded_weight)