# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StableLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union
Hyunsung Lee's avatar
Hyunsung Lee committed
22

23
24
import torch
from torch import nn
25
from transformers import StableLmConfig
Hyunsung Lee's avatar
Hyunsung Lee committed
26

27
from vllm.attention import Attention, AttentionMetadata
28
from vllm.config import CacheConfig, VllmConfig
29
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
30
from vllm.model_executor.layers.activation import SiluAndMul
31
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
32
33
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
from vllm.model_executor.layers.quantization import QuantizationConfig
36
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
37
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
39
    ParallelLMHead, VocabParallelEmbedding)
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
41
from vllm.model_executor.sampling_metadata import SamplingMetadata
42
from vllm.sequence import IntermediateTensors
Hyunsung Lee's avatar
Hyunsung Lee committed
43

44
45
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
46
47
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
48

49
50
51
52

class StablelmMLP(nn.Module):
    """Gated feed-forward sub-block used by every StableLM decoder layer.

    The checkpoint's ``gate_proj`` and ``up_proj`` matrices are held in one
    fused column-parallel projection (``gate_up_proj``, shards 0 and 1);
    ``SiluAndMul`` combines its output before the row-parallel ``down_proj``
    maps back to ``hidden_size``. Neither projection has a bias.
    """

    def __init__(self,
                 config: StableLmConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        # A single matmul produces both the gate and the up projection.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, gated activation, then down projection."""
        fused, _ = self.gate_up_proj(x)
        projected, _ = self.down_proj(self.act_fn(fused))
        return projected


class StablelmAttention(nn.Module):
    """Self-attention for StableLM with optional grouped KV heads.

    Query, key and value come from one fused ``QKVParallelLinear``.
    Checkpoints store them as separate ``q_proj``/``k_proj``/``v_proj``
    tensors, which ``StablelmForCausalLM.load_weights`` routes into it as
    shards. Rotary embeddings cover only the first ``rotary_ndims`` channels
    of each head.
    """

    def __init__(self,
                 config: StableLmConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_key_value_heads = config.num_key_value_heads
        if self.total_num_key_value_heads >= tp_size:
            # At least as many KV heads as TP ranks, so we partition the KV
            # heads across the tensor parallel GPUs.
            assert self.total_num_key_value_heads % tp_size == 0
        else:
            # Fewer KV heads than TP ranks, so we replicate the KV heads
            # across the tensor parallel GPUs.
            assert tp_size % self.total_num_key_value_heads == 0
        self.num_key_value_heads = max(
            1, self.total_num_key_value_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads

        # Validate the head layout before head_dim is used below (a zero
        # head_dim would otherwise fail in ``head_dim**-0.5``). Together these
        # two checks are equivalent to requiring
        # head_dim * num_heads * tp_size == hidden_size, but each message
        # names the value that is actually inconsistent.
        if self.hidden_size % self.total_num_heads != 0:
            raise ValueError(f"hidden_size must be divisible by num_heads "
                             f"(got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {self.total_num_heads}).")
        if self.total_num_heads % tp_size != 0:
            raise ValueError(f"num_heads must be divisible by the tensor "
                             f"parallel size (got `num_heads`: "
                             f"{self.total_num_heads} and tensor parallel "
                             f"size: {tp_size}).")

        self.max_position_embeddings = config.max_position_embeddings
        # Fraction of each head that receives rotary embeddings: read from
        # ``rope_pct`` or, failing that, ``partial_rotary_factor``; 1 means
        # the whole head is rotated.
        rope_pct = getattr(config, "rope_pct",
                           getattr(config, "partial_rotary_factor", 1))
        self.rotary_ndims = int(self.head_dim * rope_pct)
        self.scaling = self.head_dim**-0.5
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim
        self.qkv_bias = getattr(config, "use_qkv_bias", False)

        self.qkv_proj = QKVParallelLinear(self.hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_key_value_heads,
                                          self.qkv_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.qkv_proj")
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_ndims,
            max_position=self.config.max_position_embeddings,
            base=self.config.rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_key_value_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute self-attention over ``hidden_states``.

        Args:
            positions: token positions fed to the rotary embedding.
            hidden_states: input activations for this layer.
            kv_cache: this layer's KV-cache tensor, passed to ``Attention``.
            attn_metadata: batch metadata, passed to ``Attention``.

        Returns:
            The output of ``o_proj`` (width ``hidden_size``).
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class StablelmDecoderLayer(nn.Module):
    """One pre-norm StableLM transformer block.

    Computes ``h = x + attn(LN1(x))`` and then ``h + mlp(LN2(h))``. Both
    norms are ``nn.LayerNorm`` with the epsilon taken from the config.
    """

    def __init__(
        self,
        config: StableLmConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.self_attn = StablelmAttention(config,
                                           cache_config,
                                           quant_config,
                                           prefix=f"{prefix}.self_attn")
        self.mlp = StablelmMLP(config, quant_config, prefix=f"{prefix}.mlp")
        # Epsilon: ``norm_eps`` if the config has it, else
        # ``layer_norm_eps``, else 1e-5.
        fallback_eps = getattr(config, "layer_norm_eps", 1e-05)
        eps = getattr(config, "norm_eps", fallback_eps)
        width = config.hidden_size
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention then MLP, each wrapped in a residual connection.

        Returns:
            ``(output, residual)`` where ``residual`` is the post-attention
            activation that was added back after the MLP.
        """
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        post_attn = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(post_attn))
        return post_attn + mlp_out, post_attn


class StableLMEpochModel(nn.Module):
    """StableLM decoder stack: token embedding, decoder layers, final norm.

    Pipeline-parallel aware. This rank runs layers
    ``[start_layer, end_layer)`` as returned by ``make_layers``. Ranks other
    than the last return ``IntermediateTensors`` instead of normed output.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        # ``maybe_prefix`` is the same helper the causal-LM wrapper uses for
        # "model". It is intended to join only a non-empty prefix, so a
        # default "" no longer yields ".embed_tokens" / ".layers".
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: StablelmDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=maybe_prefix(prefix, "layers"),
        )
        # Epsilon: ``norm_eps`` if present, else ``layer_norm_eps``, else 1e-5.
        norm_eps = getattr(config, "norm_eps",
                           getattr(config, "layer_norm_eps", 1e-05))
        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        On the first pipeline rank the input is ``inputs_embeds`` if given,
        otherwise the embedding of ``input_ids``. Later ranks read
        ``intermediate_tensors["hidden_states"]``, which must be provided.

        Returns:
            ``IntermediateTensors`` on non-last ranks; on the last rank, the
            hidden states after the final LayerNorm.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            # ``kv_caches`` is indexed relative to this rank's first layer.
            # The per-layer residual output is not needed here.
            hidden_states, _ = self.layers[i](
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return self.norm(hidden_states)


269
class StablelmForCausalLM(nn.Module, SupportsPP):
    """StableLM causal language model: decoder stack plus LM head.

    Wraps ``StableLMEpochModel``, adds a vocab-parallel ``lm_head`` (tied
    to the input embedding when ``config.tie_word_embeddings`` is set), and
    provides logits computation, sampling and HuggingFace weight loading.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = StableLMEpochModel(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
        # Use ``maybe_prefix`` like ``model`` above: with the default empty
        # prefix the old f-string produced ".lm_head".
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack; see ``StableLMEpochModel.forward``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HuggingFace checkpoint tensors into this model.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/
        ``up_proj`` tensors are loaded as shards of the fused ``qkv_proj``
        and ``gate_up_proj`` parameters.

        Args:
            weights: ``(name, tensor)`` pairs in checkpoint naming.

        Returns:
            The parameter names (in this model's naming) that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                # When tied, lm_head shares the embedding Parameter, and
                # ``named_parameters()`` de-duplicates shared tensors, so
                # there is no "lm_head.weight" key in params_dict. Skip a
                # redundant copy instead of failing with KeyError below.
                continue

            # Map a checkpoint shard name onto its fused parameter once,
            # before any skip checks. Re-matching a rewritten name against
            # later entries is unsafe: "v_proj" is a substring of "qkv_proj"
            # and "up_proj" of "gate_up_proj".
            shard_id = None
            for (param_name, weight_name,
                 mapped_shard_id) in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = mapped_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # ``shard_id`` can be 0, so compare against None explicitly.
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params