"vllm/vscode:/vscode.git/clone" did not exist on "781d0562809b34f0c548cd354bbc01c861814f94"
stablelm.py 13.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
# All rights reserved.
Hyunsung Lee's avatar
Hyunsung Lee committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
22
23
"""Inference-only StableLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
24

25
from collections.abc import Iterable
26
from itertools import islice
Hyunsung Lee's avatar
Hyunsung Lee committed
27

28
29
import torch
from torch import nn
30
from transformers import StableLmConfig
Hyunsung Lee's avatar
Hyunsung Lee committed
31

32
from vllm.attention import Attention
33
from vllm.config import CacheConfig, VllmConfig
34
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
35
from vllm.model_executor.layers.activation import SiluAndMul
36
37
38
39
40
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.quantization import QuantizationConfig
43
from vllm.model_executor.layers.rotary_embedding import get_rope
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
45
46
47
    ParallelLMHead,
    VocabParallelEmbedding,
)
48
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
49
from vllm.sequence import IntermediateTensors
Hyunsung Lee's avatar
Hyunsung Lee committed
50

51
from .interfaces import SupportsPP
52
53
54
55
56
57
58
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
59

60
61

class StablelmMLP(nn.Module):
    """Gated feed-forward sub-block of a StableLM decoder layer.

    A merged column-parallel projection produces the gate and up activations
    in one matmul; they are combined by ``SiluAndMul`` and mapped back to
    ``hidden_size`` by a row-parallel down projection. Neither projection
    carries a bias.
    """

    def __init__(
        self,
        config: StableLmConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        inner = config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner

        # Two output shards of equal width: shard 0 is the gate projection,
        # shard 1 the up projection (matching the loader's stacked mapping).
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden,
            [inner, inner],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            inner,
            hidden,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run gate/up projection, gated activation, then down projection."""
        fused, _bias = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _bias = self.down_proj(activated)
        return out


class StablelmAttention(nn.Module):
    """Self-attention for StableLM with optional grouped KV heads.

    Q/K/V come from one fused ``QKVParallelLinear``; rotary embeddings are
    applied to q and k (with ``partial_rotary_factor`` forwarded to
    ``get_rope``), attention runs through vLLM's ``Attention`` layer, and a
    row-parallel ``o_proj`` maps the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config: StableLmConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_key_value_heads = config.num_key_value_heads
        if self.total_num_key_value_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_key_value_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_key_value_heads == 0
        # Each rank holds at least one KV head (replicated when there are
        # fewer KV heads than ranks).
        self.num_key_value_heads = max(1, self.total_num_key_value_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        # Config may name this ``rope_pct`` (checked first) or
        # ``partial_rotary_factor``; falls back to 1 when neither is set.
        self.partial_rotary_factor = getattr(
            config, "rope_pct", getattr(config, "partial_rotary_factor", 1)
        )
        self.scaling = self.head_dim**-0.5
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim
        self.qkv_bias = getattr(config, "use_qkv_bias", False)
        # head_dim and num_heads are both floor divisions, so this single
        # check rejects a hidden_size that is not a multiple of num_heads AND
        # a num_heads that is not a multiple of the tensor parallel size.
        if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads and num_heads "
                f"must be divisible by the tensor parallel size "
                f"(got `hidden_size`: {self.hidden_size}, "
                f"`num_heads`: {self.total_num_heads} and "
                f"`tp_size`: {tp_size})."
            )

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_key_value_heads,
            self.qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.config.max_position_embeddings,
            base=self.config.rope_theta,
            partial_rotary_factor=self.partial_rotary_factor,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into this rank's q, k and v slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class StablelmDecoderLayer(nn.Module):
    """One StableLM transformer layer: attention sub-block then MLP sub-block.

    Each sub-block normalizes its input with ``nn.LayerNorm``, applies the
    sub-block, and adds the un-normalized input back (pre-norm residual).
    """

    def __init__(
        self,
        config: StableLmConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.self_attn = StablelmAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
        )
        self.mlp = StablelmMLP(config, quant_config, prefix=f"{prefix}.mlp")
        # Configs differ in which epsilon attribute they define; prefer
        # ``norm_eps``, then ``layer_norm_eps``, then 1e-05.
        eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
        width = config.hidden_size
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Returns:
            ``(output, residual)`` where ``residual`` is the hidden state after
            the attention residual add, i.e. the input to the MLP sub-block.
        """
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
        )
        after_attn = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out, after_attn


class StableLMEpochModel(nn.Module):
    """StableLM decoder stack: token embedding, decoder layers, final norm.

    Pipeline-parallel aware: only the first PP rank embeds tokens (or uses the
    supplied ``inputs_embeds``), only layers ``[start_layer, end_layer)`` (as
    returned by ``make_layers``) are run, and only the last PP rank applies the
    final ``LayerNorm``; other ranks return ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: StablelmDecoderLayer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Returns the normalized final hidden states on the last PP rank, or an
        ``IntermediateTensors`` carrying ``hidden_states`` on other ranks.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            # Layers also return their residual; this model does not use it.
            hidden_states, _ = layer(positions, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def _is_skipped_weight(self, name: str, params_dict: dict) -> bool:
        """Return True if checkpoint tensor ``name`` must not be loaded here."""
        # Skip loading extra bias for GPTQ models.
        if name.endswith(".bias") and name not in params_dict:
            return True
        # Parameter belongs to a layer owned by another PP rank.
        return is_pp_missing_parameter(name, self)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj``
        tensors are written as shards of the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters; every other tensor is loaded by name.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # The tensor matched a stacked shard: decide now and stop
                # scanning. The rewritten name contains later shard names
                # ("v_proj" in "qkv_proj", "up_proj" in "gate_up_proj") and
                # must not be rewritten a second time.
                if self._is_skipped_weight(name, params_dict):
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                if self._is_skipped_weight(name, params_dict):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        return loaded_params

311

312
class StablelmForCausalLM(nn.Module, SupportsPP):
    """StableLM causal language model: decoder stack plus vocabulary head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = StableLMEpochModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        # Use maybe_prefix (like ``model`` above) so an empty prefix yields
        # "lm_head" rather than ".lm_head"; the prefix is handed to the
        # quantization config, which may match layers by name.
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # Output projection reuses the input embedding parameter.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the underlying decoder stack."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack; returns hidden states, not logits."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, delegating per-submodule to AutoWeightsLoader."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)