stablelm.py 14.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
# All rights reserved.
Hyunsung Lee's avatar
Hyunsung Lee committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
22
23
"""Inference-only StableLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
24
from collections.abc import Iterable
25
from itertools import islice
26
from typing import Optional, Union
Hyunsung Lee's avatar
Hyunsung Lee committed
27

28
29
import torch
from torch import nn
30
from transformers import StableLmConfig
Hyunsung Lee's avatar
Hyunsung Lee committed
31

32
from vllm.attention import Attention
33
from vllm.config import CacheConfig, VllmConfig
34
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
35
from vllm.model_executor.layers.activation import SiluAndMul
36
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
37
38
                                               QKVParallelLinear,
                                               RowParallelLinear)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
from vllm.model_executor.layers.rotary_embedding import get_rope
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
43
    ParallelLMHead, VocabParallelEmbedding)
44
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
45
from vllm.model_executor.sampling_metadata import SamplingMetadata
46
from vllm.sequence import IntermediateTensors
Hyunsung Lee's avatar
Hyunsung Lee committed
47

48
from .interfaces import SupportsPP
49
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
50
51
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
52

53
54
55
56

class StablelmMLP(nn.Module):
    """Gated feed-forward block of a StableLM decoder layer.

    A fused, bias-free gate/up projection (two shards of width
    ``intermediate_size``) feeds ``SiluAndMul``. A bias-free down projection
    then maps the result back to ``hidden_size``.
    """

    def __init__(self,
                 config: StableLmConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size
        self.config = config
        self.hidden_size = hidden
        self.intermediate_size = inner
        # One column-parallel matmul produces both the gate and the up shard.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden,
            [inner, inner],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            inner,
            hidden,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a 2-tuple; only the first
        # element (the projected activations) is used.
        fused = self.gate_up_proj(x)[0]
        return self.down_proj(self.act_fn(fused))[0]


class StablelmAttention(nn.Module):
    """Self-attention for StableLM with tensor-parallel projections.

    Q, K and V come from one fused ``QKVParallelLinear``. Rotary embeddings
    from ``get_rope`` are applied to Q and K. The attention output is
    projected back to ``hidden_size`` by a row-parallel ``o_proj``. Key/value
    heads may be fewer than query heads; they are partitioned across tensor
    parallel ranks, or replicated when there are fewer KV heads than ranks.
    """

    def __init__(self,
                 config: StableLmConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_key_value_heads = config.num_key_value_heads
        if self.total_num_key_value_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_key_value_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_key_value_heads == 0
        self.num_key_value_heads = max(
            1, self.total_num_key_value_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        # Accept either config attribute name; the value is forwarded to
        # get_rope as `partial_rotary_factor` (default 1).
        self.partial_rotary_factor = getattr(
            config, "rope_pct", getattr(config, "partial_rotary_factor", 1))
        self.scaling = self.head_dim**-0.5
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim
        self.qkv_bias = getattr(config, "use_qkv_bias", False)
        # head_dim and num_heads are floor divisions, so this product equals
        # hidden_size only if hidden_size is divisible by the total head count
        # AND the total head count is divisible by the TP size. Report the
        # totals and the TP size so either failure mode is diagnosable.
        if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads and num_heads "
                f"must be divisible by the tensor parallel size (got "
                f"`hidden_size`: {self.hidden_size}, `num_heads`: "
                f"{self.total_num_heads}, tensor parallel size: {tp_size}).")

        self.qkv_proj = QKVParallelLinear(self.hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_key_value_heads,
                                          self.qkv_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.qkv_proj")
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.config.max_position_embeddings,
            base=self.config.rope_theta,
            partial_rotary_factor=self.partial_rotary_factor,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_key_value_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention for this rank's heads and return the projected output."""
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused output's last dimension is laid out as [q | k | v]
        # for this rank's share of heads.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class StablelmDecoderLayer(nn.Module):
    """One pre-norm StableLM transformer block.

    The block computes LayerNorm -> self-attention -> residual add, then
    LayerNorm -> MLP -> residual add.
    """

    def __init__(
        self,
        config: StableLmConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.self_attn = StablelmAttention(config,
                                           cache_config,
                                           quant_config,
                                           prefix=f"{prefix}.self_attn")
        self.mlp = StablelmMLP(config, quant_config, prefix=f"{prefix}.mlp")

        # The epsilon may be stored as `norm_eps` or `layer_norm_eps`;
        # fall back to 1e-05 when neither attribute exists.
        fallback_eps = getattr(config, "layer_norm_eps", 1e-05)
        eps = getattr(config, "norm_eps", fallback_eps)
        width = config.hidden_size
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the block.

        Returns:
            ``(output, after_attn)``. ``after_attn`` is the attention
            sub-block's output (the residual fed into the MLP sub-block).
        """
        normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions, hidden_states=normed)
        after_attn = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out, after_attn


class StableLMEpochModel(nn.Module):
    """StableLM decoder stack: token embeddings, decoder layers, final norm.

    ``make_layers`` returns ``start_layer``/``end_layer`` bounds, and only
    that slice of ``self.layers`` is run in ``forward``. When this is not the
    last pipeline rank, ``forward`` returns ``IntermediateTensors`` instead of
    normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        # maybe_prefix avoids a leading "." when `prefix` is empty; it is
        # identical to f"{prefix}.<name>" otherwise.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: StablelmDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=maybe_prefix(prefix, "layers"),
        )
        norm_eps = getattr(config, "norm_eps",
                           getattr(config, "layer_norm_eps", 1e-05))
        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the vocab-parallel embedding table."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of decoder layers.

        On the first pipeline rank, the input is ``inputs_embeds`` if given,
        otherwise the embedding of ``input_ids``. On later ranks it is
        ``intermediate_tensors["hidden_states"]``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            # The layer's second output (its residual) is not needed here.
            hidden_states, _ = layer(positions, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Tensors named with ``q_proj``/``k_proj``/``v_proj`` or
        ``gate_proj``/``up_proj`` are renamed to the fused ``qkv_proj`` /
        ``gate_up_proj`` parameter and written into the corresponding shard.
        All other tensors are loaded by name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Resolve the target name first, stopping at the first matching
            # rule. Searching on with an already-rewritten name would let a
            # later rule match inside it ("v_proj" occurs in "qkv_proj",
            # "up_proj" in "gate_up_proj") and mangle the name further.
            shard_id: Optional[Union[str, int]] = None
            for param_name, weight_name, mapped_shard in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = mapped_shard
                    break
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Presumably parameters of layers not materialized on this
            # pipeline rank; they have no entry to load into.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused parameters need their shard-aware loader.
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

297

298
class StablelmForCausalLM(nn.Module, SupportsPP):
    """StableLM causal language model: decoder stack plus LM head.

    If ``config.tie_word_embeddings`` is set, the LM head shares its weight
    with the input embedding table. Weight loading is delegated to
    ``AutoWeightsLoader``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = StableLMEpochModel(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
        # Build the prefix the same way as for `model` above, so that an
        # empty `prefix` yields "lm_head" rather than ".lm_head" (the f-string
        # form). For a non-empty prefix both forms are identical.
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids via the underlying decoder's embedding table."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack; logits are computed in `compute_logits`."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through `lm_head`."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the set of loaded names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)