stablelm.py 14.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
# All rights reserved.
Hyunsung Lee's avatar
Hyunsung Lee committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
21
22
"""Inference-only StableLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
23
from typing import Iterable, Optional, Set, Tuple, Union
Hyunsung Lee's avatar
Hyunsung Lee committed
24

25
26
import torch
from torch import nn
27
from transformers import StableLmConfig
Hyunsung Lee's avatar
Hyunsung Lee committed
28

29
from vllm.attention import Attention
30
from vllm.config import CacheConfig, VllmConfig
31
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
32
from vllm.model_executor.layers.activation import SiluAndMul
33
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
34
35
                                               QKVParallelLinear,
                                               RowParallelLinear)
36
from vllm.model_executor.layers.logits_processor import LogitsProcessor
37
from vllm.model_executor.layers.quantization import QuantizationConfig
38
from vllm.model_executor.layers.rotary_embedding import get_rope
39
from vllm.model_executor.layers.vocab_parallel_embedding import (
40
    ParallelLMHead, VocabParallelEmbedding)
41
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
42
from vllm.model_executor.sampling_metadata import SamplingMetadata
43
from vllm.sequence import IntermediateTensors
Hyunsung Lee's avatar
Hyunsung Lee committed
44

45
from .interfaces import SupportsPP
46
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
47
48
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
49

50
51
52
53

class StablelmMLP(nn.Module):
    """Gated feed-forward block of a StableLM decoder layer.

    The input goes through a merged gate/up projection, then ``SiluAndMul``,
    then ``down_proj`` back to the hidden size. Neither projection has a
    bias.
    """

    def __init__(self,
                 config: StableLmConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        """Create the projections and the activation.

        Args:
            config: Model configuration; ``hidden_size`` and
                ``intermediate_size`` are read from it.
            quant_config: Optional quantization settings, forwarded to both
                linear layers.
            prefix: Dotted name of this module, used to build the prefixes
                of the two linear layers.
        """
        super().__init__()
        hidden_dim = config.hidden_size
        ffn_dim = config.intermediate_size

        self.config = config
        self.hidden_size = hidden_dim
        self.intermediate_size = ffn_dim
        # Gate and up projections live in one merged layer with two output
        # slices of equal size.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_dim,
            [ffn_dim, ffn_dim],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            ffn_dim,
            hidden_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection to x."""
        # Both linear layers return a pair; only the first item is used.
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected


class StablelmAttention(nn.Module):
    """Self-attention block of a StableLM decoder layer.

    Queries, keys and values come from one fused ``QKVParallelLinear``
    projection. Rotary position embeddings are applied to the queries and
    keys, and the output of ``Attention`` is projected by ``o_proj``. The
    attention heads are divided by the tensor-parallel world size.
    """

    def __init__(self,
                 config: StableLmConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        """Create the projections, the rotary embedding and the attention op.

        Args:
            config: Model configuration (hidden size, head counts, rotary
                settings, optional ``use_qkv_bias``).
            cache_config: Optional cache settings, forwarded to
                ``Attention``.
            quant_config: Optional quantization settings, forwarded to the
                linear layers and to ``Attention``.
            prefix: Dotted name of this module, used to build the prefixes
                of the sub-layers.

        Raises:
            ValueError: If ``head_dim * num_heads * tp_size`` does not equal
                ``hidden_size``, i.e. the hidden size is not divisible by
                the number of attention heads or that number is not
                divisible by the tensor-parallel size.
            AssertionError: If the number of KV heads and the
                tensor-parallel size do not divide one another.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_key_value_heads = config.num_key_value_heads
        if self.total_num_key_value_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_key_value_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_key_value_heads == 0
        # Every rank keeps at least one KV head.
        self.num_key_value_heads = max(
            1, self.total_num_key_value_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        # "rope_pct" takes precedence; "partial_rotary_factor" is the
        # fallback, and 1 is used when the config defines neither.
        self.partial_rotary_factor = getattr(
            config, "rope_pct", getattr(config, "partial_rotary_factor", 1))
        self.scaling = self.head_dim**-0.5
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim
        self.qkv_bias = getattr(config, "use_qkv_bias", False)
        # Both integer divisions above must have been exact. Report the
        # total head count and the TP size, since either one can be the
        # reason for the mismatch.
        if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
            raise ValueError(
                "hidden_size must be divisible by num_attention_heads, and "
                "num_attention_heads must be divisible by the tensor "
                f"parallel size (got `hidden_size`: {self.hidden_size}, "
                f"`num_attention_heads`: {self.total_num_heads} and "
                f"`tp_size`: {tp_size}).")

        self.qkv_proj = QKVParallelLinear(self.hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_key_value_heads,
                                          self.qkv_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.qkv_proj")
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.config.max_position_embeddings,
            base=self.config.rope_theta,
            partial_rotary_factor=self.partial_rotary_factor,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_key_value_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            positions: Position tensor handed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [q | k | v] on the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class StablelmDecoderLayer(nn.Module):
    """One StableLM decoder layer: attention and MLP, each with a residual.

    A ``LayerNorm`` is applied before each sub-block and the sub-block
    output is added to its un-normalized input.
    """

    def __init__(
        self,
        config: StableLmConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the attention block, the MLP and the two layer norms.

        Args:
            config: Model configuration.
            cache_config: Optional cache settings for the attention block.
            quant_config: Optional quantization settings for both blocks.
            prefix: Dotted name of this layer, used to build the prefixes
                of the sub-modules.
        """
        super().__init__()
        self.self_attn = StablelmAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = StablelmMLP(config, quant_config, prefix=f"{prefix}.mlp")

        # "norm_eps" wins over "layer_norm_eps"; 1e-05 is the last resort.
        fallback_eps = getattr(config, "layer_norm_eps", 1e-05)
        eps = getattr(config, "norm_eps", fallback_eps)
        width = config.hidden_size
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer.

        Args:
            positions: Position tensor forwarded to the attention block.
            hidden_states: Input of the layer.

        Returns:
            A pair ``(output, residual)`` where ``output`` is the result of
            the layer and ``residual`` is the tensor that entered the MLP
            sub-block (the input plus the attention output).
        """
        # Attention sub-block with its residual connection.
        post_attn = hidden_states + self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
        )

        # MLP sub-block with its residual connection.
        layer_out = post_attn + self.mlp(
            self.post_attention_layernorm(post_attn))

        return layer_out, post_attn


class StableLMEpochModel(nn.Module):
    """StableLM decoder stack: embedding, decoder layers and final norm.

    Under pipeline parallelism only the layers in
    ``[start_layer, end_layer)`` are executed by this rank. The first rank
    embeds the tokens and the last rank applies the final ``LayerNorm``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model from ``vllm_config``.

        Args:
            vllm_config: Supplies the HF model config, the cache config and
                the quantization config.
            prefix: Dotted name of this module, used to build the prefixes
                of the sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: StablelmDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        # "norm_eps" wins over "layer_norm_eps"; 1e-05 is the last resort.
        norm_eps = getattr(config, "norm_eps",
                           getattr(config, "layer_norm_eps", 1e-05))
        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's share of the decoder stack.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Position tensor forwarded to every layer.
            intermediate_tensors: Output of the previous pipeline rank;
                must not be ``None`` on ranks other than the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            The normalized hidden states on the last pipeline rank,
            otherwise an ``IntermediateTensors`` holding ``hidden_states``
            for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            # The second item returned by the layer is not needed here.
            hidden_states, _ = layer(positions, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/
        ``up_proj`` checkpoint tensors are routed into the fused
        ``qkv_proj`` and ``gate_up_proj`` parameters with the matching
        shard id. Bias tensors without a matching parameter and parameters
        reported by ``is_pp_missing_parameter`` are skipped.

        Args:
            weights: Pairs of checkpoint name and tensor.

        Returns:
            The names of the parameters that were loaded.

        Raises:
            KeyError: If a tensor that is not skipped has no parameter of
                the resolved name.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Resolve the checkpoint name to its (possibly fused) parameter
            # name exactly once. Scanning further after a rewrite would
            # match "v_proj" inside the freshly inserted "qkv_proj" and
            # corrupt the name. None marks a non-stacked weight; 0 is a
            # valid shard id, so the check below must be an identity test.
            shard_id = None
            for param_name, weight_name, mapped_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = mapped_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

294

295
class StablelmForCausalLM(nn.Module, SupportsPP):
    """StableLM decoder stack with a language-model head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack, the LM head and the logits processor.

        Args:
            vllm_config: Supplies the HF model config and the quantization
                config; also handed to the inner model.
            prefix: Dotted name of this module; may be empty.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = StableLMEpochModel(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
        # Build the head prefix the same way as the model prefix, so an
        # empty `prefix` yields "lm_head" rather than ".lm_head".
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if self.config.tie_word_embeddings:
            # Share the embedding matrix with the output head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings of the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the inner model and return its result unchanged.

        Args:
            input_ids: Token ids.
            positions: Position tensor.
            intermediate_tensors: Optional output of the previous pipeline
                rank.
            inputs_embeds: Optional precomputed input embeddings.

        Returns:
            Whatever the inner model returns: hidden states or an
            ``IntermediateTensors``.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits by running the logits processor with the LM head.

        Args:
            hidden_states: Hidden states produced by ``forward``.
            sampling_metadata: Forwarded to the logits processor.

        Returns:
            The result of the logits processor.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors through ``AutoWeightsLoader``.

        Args:
            weights: Pairs of checkpoint name and tensor.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(
            self,
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            skip_prefixes=[
                "rotary_emb.inv_freq", "rotary_emb.cos_cached",
                "rotary_emb.sin_cached"
            ],
        )
        return loader.load_weights(weights)