"tests/vscode:/vscode.git/clone" did not exist on "2461d9e562e5852555c76e0dbed06979f9c6c688"
stablelm.py 15.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
# All rights reserved.
Hyunsung Lee's avatar
Hyunsung Lee committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
21
22
"""Inference-only StableLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
23
from typing import Iterable, List, Optional, Set, Tuple, Union
Hyunsung Lee's avatar
Hyunsung Lee committed
24

25
26
import torch
from torch import nn
27
from transformers import StableLmConfig
Hyunsung Lee's avatar
Hyunsung Lee committed
28

29
from vllm.attention import Attention, AttentionMetadata
30
from vllm.config import CacheConfig, VllmConfig
31
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
32
from vllm.model_executor.layers.activation import SiluAndMul
33
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
34
35
                                               QKVParallelLinear,
                                               RowParallelLinear)
36
from vllm.model_executor.layers.logits_processor import LogitsProcessor
37
from vllm.model_executor.layers.quantization import QuantizationConfig
38
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
39
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
40
from vllm.model_executor.layers.vocab_parallel_embedding import (
41
    ParallelLMHead, VocabParallelEmbedding)
42
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
43
from vllm.model_executor.sampling_metadata import SamplingMetadata
44
from vllm.sequence import IntermediateTensors
Hyunsung Lee's avatar
Hyunsung Lee committed
45

46
47
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
48
49
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
50

51
52
53
54

class StablelmMLP(nn.Module):
    """Gated feed-forward block of a StableLM decoder layer.

    The gate and up projections are fused into a single column-parallel
    matmul whose output is fed to ``SiluAndMul``; the activated result is
    projected back to ``hidden_size`` by a row-parallel ``down_proj``.
    Neither projection carries a bias.
    """

    def __init__(self,
                 config: StableLmConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size
        self.config = config
        self.hidden_size = hidden
        self.intermediate_size = inner
        # One fused weight holds both the gate and the up projection.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden,
            [inner, inner],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            inner,
            hidden,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x`` and return the projected output."""
        # The parallel linear layers return (output, bias); bias is unused.
        fused, _unused_bias = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _unused_bias = self.down_proj(activated)
        return out


class StablelmAttention(nn.Module):
    """Self-attention layer of a StableLM decoder block.

    Q, K and V come from one fused tensor-parallel projection. Rotary
    position embeddings cover the first ``rotary_ndims`` dimensions of each
    head (``rope_pct`` / ``partial_rotary_factor`` times ``head_dim``). The
    attention output is projected back to ``hidden_size`` by a row-parallel
    ``o_proj``. KV heads are partitioned across tensor-parallel ranks when
    there are at least as many as ranks, and replicated otherwise.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by the number of
            attention heads, or the number of attention heads is not
            divisible by the tensor-parallel world size.
    """

    def __init__(self,
                 config: StableLmConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        # Validate the head layout up front. The two conditions are reported
        # separately with the *total* head count, so a bad TP size is not
        # misreported as a hidden_size problem with a per-rank head count.
        if self.hidden_size % self.total_num_heads != 0:
            raise ValueError(f"hidden_size must be divisible by num_heads "
                             f"(got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {self.total_num_heads}).")
        if self.total_num_heads % tp_size != 0:
            raise ValueError(f"num_heads must be divisible by the tensor "
                             f"parallel size (got `num_heads`: "
                             f"{self.total_num_heads} and tensor parallel "
                             f"size: {tp_size}).")
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_key_value_heads = config.num_key_value_heads
        if self.total_num_key_value_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_key_value_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_key_value_heads == 0
        self.num_key_value_heads = max(
            1, self.total_num_key_value_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        # The rotated fraction of each head may be spelled either way in the
        # config; default to rotating the whole head.
        rope_pct = getattr(config, "rope_pct",
                           getattr(config, "partial_rotary_factor", 1))
        self.rotary_ndims = int(self.head_dim * rope_pct)
        self.scaling = self.head_dim**-0.5
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim
        self.qkv_bias = getattr(config, "use_qkv_bias", False)

        self.qkv_proj = QKVParallelLinear(self.hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_key_value_heads,
                                          self.qkv_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.qkv_proj")
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_ndims,
            max_position=self.config.max_position_embeddings,
            base=self.config.rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_key_value_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states`` and return the
        ``o_proj``-projected output.

        ``positions`` feed the rotary embedding; ``kv_cache`` and
        ``attn_metadata`` are passed through to the ``Attention`` backend.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class StablelmDecoderLayer(nn.Module):
    """One StableLM transformer block.

    Each sub-block is pre-normalized with a LayerNorm and wrapped in a
    residual connection: attention first, then the MLP.
    """

    def __init__(
        self,
        config: StableLmConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.self_attn = StablelmAttention(config,
                                           cache_config,
                                           quant_config,
                                           prefix=f"{prefix}.self_attn")
        self.mlp = StablelmMLP(config, quant_config, prefix=f"{prefix}.mlp")
        # The epsilon may be stored under either name; fall back to 1e-05.
        fallback_eps = getattr(config, "layer_norm_eps", 1e-05)
        eps = getattr(config, "norm_eps", fallback_eps)
        width = config.hidden_size
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the block to ``hidden_states``.

        Returns:
            ``(output, residual)``, where ``residual`` is the hidden state
            after the attention sub-block (the MLP's residual input).
        """
        # Attention sub-block with a residual connection around it.
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        after_attn = hidden_states + attn_out

        # MLP sub-block with its own residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out, after_attn


class StableLMEpochModel(nn.Module):
    """StableLM decoder stack without the LM head.

    The stack holds the token embedding, the decoder layers assigned to this
    pipeline-parallel rank (``start_layer``..``end_layer``) and the final
    LayerNorm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_cfg = vllm_config.cache_config
        quant_cfg = vllm_config.quant_config

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_cfg,
            prefix=f"{prefix}.embed_tokens",
        )

        # make_layers invokes the factory with a ``prefix=`` keyword.
        def build_layer(prefix: str):
            return StablelmDecoderLayer(hf_config,
                                        cache_cfg,
                                        quant_cfg,
                                        prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        fallback_eps = getattr(hf_config, "layer_norm_eps", 1e-05)
        eps = getattr(hf_config, "norm_eps", fallback_eps)
        self.norm = nn.LayerNorm(hf_config.hidden_size, eps=eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        On the first pipeline rank the input is ``inputs_embeds`` if given,
        otherwise the embedding of ``input_ids``. Later ranks read
        ``intermediate_tensors["hidden_states"]``. ``kv_caches`` is indexed
        relative to ``start_layer``. Non-last ranks return
        ``IntermediateTensors``; the last rank returns the final-normed
        hidden states.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        for layer_idx in range(self.start_layer, self.end_layer):
            # Each layer also returns its residual, which is not needed here.
            hidden_states, _ = self.layers[layer_idx](
                positions,
                hidden_states,
                kv_caches[layer_idx - self.start_layer],
                attn_metadata,
            )

        if pp_group.is_last_rank:
            return self.norm(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


271
class StablelmForCausalLM(nn.Module, SupportsPP):
    """StableLM causal language model.

    Combines the decoder stack with a tensor-parallel LM head, a logits
    processor, a sampler and a HuggingFace-checkpoint weight loader.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = StableLMEpochModel(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
        # Build the prefix the same way as for ``model`` so a top-level
        # construction yields "lm_head", not ".lm_head". Quantization
        # configs match layer names by prefix.
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack; see ``StableLMEpochModel.forward``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Separate ``q/k/v_proj`` and ``gate/up_proj`` checkpoint tensors are
        routed into the fused ``qkv_proj`` / ``gate_up_proj`` parameters
        through their shard-aware ``weight_loader``. Every other tensor is
        loaded by the parameter's ``weight_loader``, or by
        ``default_weight_loader`` when it has none. Rotary caches, extra
        GPTQ biases, parameters owned by other pipeline ranks and (with tied
        embeddings) ``lm_head.weight`` are skipped.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                # lm_head shares embed_tokens' tensor, and named_parameters()
                # de-duplicates shared tensors, so there is no separate
                # "lm_head.weight" entry to load into.
                continue

            # Map a split checkpoint tensor onto its fused parameter. Only
            # the first matching entry is applied, so an already-renamed name
            # (e.g. "gate_up_proj", which contains "up_proj") is never
            # rewritten a second time.
            shard_id = None
            for param_name, weight_name, shard in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = shard
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # shard_id may legitimately be 0, so compare against None.
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params