# starcoder2.py
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import Starcoder2Config

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class Starcoder2Attention(nn.Module):
    """StarCoder2 self-attention layer with rotary position embeddings.

    Queries, keys and values come from one fused ``QKVParallelLinear``
    projection. The attention result is mapped back to ``hidden_size`` by a
    ``RowParallelLinear``. Query heads are divided evenly across the
    tensor-parallel group. When the model has fewer KV heads than TP ranks,
    the KV heads are replicated instead of partitioned.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: Hugging Face ``Starcoder2Config`` for the model.
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Optional quantization config forwarded to the
                linear layers and to ``Attention``.

        Raises:
            AssertionError: If the head counts are incompatible with the
                tensor-parallel world size.
        """
        super().__init__()
        self.config = config

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be "
            f"divisible by the tensor parallel size ({tp_size})")
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0, (
                f"num_key_value_heads ({self.total_num_kv_heads}) must be "
                f"divisible by the tensor parallel size ({tp_size})")
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0, (
                f"tensor parallel size ({tp_size}) must be divisible by "
                f"num_key_value_heads ({self.total_num_kv_heads})")
        # In the replicated case each rank still holds one full KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # 1/sqrt(head_dim) softmax scaling.
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings
        self.use_bias = config.use_bias

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=self.use_bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=self.use_bias,
            quant_config=quant_config,
        )
        # NOTE(review): int() truncates a fractional rope_theta; kept to
        # preserve existing behavior — confirm this matches the reference.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run rotary self-attention over ``hidden_states``.

        Args:
            positions: Token positions consumed by the rotary embedding.
            hidden_states: Layer input, fed to the fused QKV projection.
            kv_cache: This layer's KV-cache tensor.
            attn_metadata: Attention-backend metadata for the current batch.

        Returns:
            The attention output after the output projection ``o_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into this rank's Q, K and V slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class Starcoder2MLP(nn.Module):
    """StarCoder2 feed-forward block: ``c_fc`` -> activation -> ``c_proj``.

    ``c_fc`` is a ``ColumnParallelLinear`` from ``hidden_size`` to
    ``intermediate_size``. ``c_proj`` is a ``RowParallelLinear`` back to
    ``hidden_size``. Both use a bias only when ``config.use_bias`` is set.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the two projections and the activation function.

        Args:
            config: Hugging Face ``Starcoder2Config`` for the model.
            quant_config: Optional quantization config forwarded to both
                linear layers and to ``get_act_fn``.
        """
        super().__init__()
        self.c_fc = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=config.use_bias,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=config.use_bias,
            quant_config=quant_config,
        )
        # Activation selected by name from config.hidden_act.
        self.act = get_act_fn(config.hidden_act, quant_config,
                              config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation and down-projection.

        Args:
            hidden_states: Block input.

        Returns:
            The output of ``c_proj``.
        """
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class Starcoder2DecoderLayer(nn.Module):
    """One StarCoder2 transformer block.

    The block applies LayerNorm, then self-attention, then a residual add.
    It then applies LayerNorm, then the MLP, then a second residual add.
    Both norms are ``nn.LayerNorm`` with ``eps=config.norm_epsilon``.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the attention, MLP and the two LayerNorms.

        Args:
            config: Hugging Face ``Starcoder2Config`` for the model.
            cache_config: KV-cache configuration forwarded to the attention.
            quant_config: Optional quantization config forwarded to the
                attention and MLP sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Starcoder2Attention(config,
                                             cache_config,
                                             quant_config=quant_config)
        self.mlp = Starcoder2MLP(config, quant_config=quant_config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.norm_epsilon)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run one decoder block.

        Args:
            positions: Token positions passed to the attention's rotary
                embedding.
            hidden_states: Block input.
            kv_cache: This layer's KV-cache tensor.
            attn_metadata: Attention-backend metadata for the current batch.

        Returns:
            The block output, the same width as the input after both
            residual additions.
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


@support_torch_compile
class Starcoder2Model(nn.Module):
    """StarCoder2 decoder stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism this rank runs only the layers in
    ``[start_layer, end_layer)``. Only the first rank embeds ``input_ids``.
    Only the last rank applies the final LayerNorm. The other ranks hand
    ``hidden_states`` on as ``IntermediateTensors``.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the embedding, this rank's decoder layers and final norm.

        Args:
            config: Hugging Face ``Starcoder2Config`` for the model.
            cache_config: KV-cache configuration forwarded to every layer.
            quant_config: Optional quantization config forwarded to every
                layer.
            prefix: Module-name prefix; layers are registered under
                ``f"{prefix}.layers"``.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # TODO: consider padding_idx (currently removed)
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        # make_layers returns the index range this rank owns plus the layer
        # container. NOTE(review): presumably layers outside the range are
        # placeholders — confirm against vllm's make_layers.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Starcoder2DecoderLayer(
                config, cache_config, quant_config=quant_config),
            prefix=f"{prefix}.layers",
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's portion of the decoder.

        Args:
            input_ids: Token ids; only read on the first pipeline rank.
            positions: Token positions passed to every layer.
            kv_caches: KV caches for this rank's layers, indexed relative
                to ``start_layer``.
            attn_metadata: Attention-backend metadata for the current batch.
            intermediate_tensors: Hidden states from the previous pipeline
                stage. Required on every rank except the first.

        Returns:
            ``IntermediateTensors`` holding ``"hidden_states"`` on
            non-last ranks. On the last rank, the final-normed hidden
            states.
        """
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(positions, hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states


class Starcoder2ForCausalLM(nn.Module, SupportsPP):
    """StarCoder2 causal language model for vLLM.

    Wraps ``Starcoder2Model`` with an LM head, a ``LogitsProcessor`` and a
    ``Sampler``. It also provides checkpoint loading that fuses separate
    q/k/v projection weights into ``qkv_proj``.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the decoder, LM head, logits processor and sampler.

        Args:
            config: Hugging Face ``Starcoder2Config`` for the model.
            cache_config: KV-cache configuration forwarded to the decoder.
            quant_config: Optional quantization config forwarded to the
                decoder and, when untied, to the LM head.
        """
        super().__init__()
        self.config = config
        self.model = Starcoder2Model(config,
                                     cache_config,
                                     quant_config=quant_config)
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size
        if config.tie_word_embeddings:
            # Reuse the input embedding module as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                quant_config=quant_config,
            )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder. Logits are computed separately by
        ``compute_logits``.

        Returns:
            Whatever ``Starcoder2Model.forward`` returns for this pipeline
            rank: final hidden states or ``IntermediateTensors``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into the parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are written into
        the matching shard of the fused ``qkv_proj`` parameter. The loader
        skips three kinds of entry:

        - rotary ``inv_freq`` buffers;
        - ``lm_head.weight`` when word embeddings are tied;
        - parameters that ``is_pp_missing_parameter`` reports as absent on
          this pipeline rank.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Whether loaded or skipped, stop scanning the mapping here.
                # The rewritten name contains "v_proj" (inside "qkv_proj")
                # and would otherwise be rewritten a second time.
                if not is_pp_missing_parameter(name, self):
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked q/k/v weight: load it one-to-one.
                if self.config.tie_word_embeddings and "lm_head.weight" in name:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)