starcoder2.py 12.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# coding=utf-8
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
21
from typing import Iterable, List, Optional, Tuple
22
23
24

import torch
from torch import nn
25
from transformers import Starcoder2Config
26

27
from vllm.attention import Attention, AttentionMetadata
28
from vllm.config import CacheConfig
29
from vllm.distributed import get_tensor_model_parallel_world_size
30
31
32
33
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
36
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
37
from vllm.model_executor.layers.rotary_embedding import get_rope
38
39
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
40
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
41
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
42
from vllm.model_executor.sampling_metadata import SamplingMetadata
43
from vllm.sequence import IntermediateTensors, SamplerOutput
44
45
46
47
48
49


class Starcoder2Attention(nn.Module):
    """Self-attention sub-block of a StarCoder2 decoder layer.

    The flow is:

    1. A fused, tensor-parallel ``QKVParallelLinear`` projects hidden
       states to queries, keys and values.
    2. NeoX-style rotary position embeddings are applied to queries and
       keys.
    3. The result goes through vLLM's ``Attention`` layer together with
       the KV cache and attention metadata.
    4. A ``RowParallelLinear`` projects back to ``hidden_size``.

    The model may have fewer key/value heads than query heads
    (``config.num_key_value_heads``). Across tensor-parallel ranks the KV
    heads are partitioned when there are at least as many KV heads as
    ranks, and replicated otherwise.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are always split evenly across tensor-parallel ranks.
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be "
            f"divisible by the tensor parallel size ({tp_size}).")
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to the TP size,
            # so we partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0, (
                f"num_key_value_heads ({self.total_num_kv_heads}) must be "
                f"divisible by the tensor parallel size ({tp_size}).")
        else:
            # Number of KV heads is less than the TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0, (
                f"The tensor parallel size ({tp_size}) must be divisible "
                f"by num_key_value_heads ({self.total_num_kv_heads}).")
        # In the replicated case the integer division yields 0; every rank
        # still holds one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank widths of the q / k / v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # 1 / sqrt(head_dim); handed to the Attention layer as its scale.
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings
        self.use_bias = config.use_bias

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=self.use_bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=self.use_bias,
            quant_config=quant_config,
        )
        # NOTE(review): rope_theta is truncated to int here; confirm that
        # get_rope's `base` parameter requires an int in this vLLM version.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            positions: Token positions fed to the rotary embedding.
            hidden_states: Input activations. Presumably laid out as
                (num_tokens, hidden_size); TODO confirm against the caller.
            kv_cache: This layer's KV-cache tensor, passed unchanged to
                ``Attention``.
            attn_metadata: Batch attention metadata, passed unchanged to
                ``Attention``.

        Returns:
            The output of the row-parallel output projection ``o_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection concatenates this rank's q, k and v on the
        # last dimension.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class Starcoder2MLP(nn.Module):
    """Feed-forward sub-block of a StarCoder2 decoder layer.

    The block runs an up-projection ``c_fc`` (column-parallel,
    hidden_size -> intermediate_size), then the activation named by
    ``config.hidden_act``. It finishes with a down-projection ``c_proj``
    (row-parallel, intermediate_size -> hidden_size). Both projections
    carry a bias when ``config.use_bias`` is set.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.c_fc = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=config.use_bias,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=config.use_bias,
            quant_config=quant_config,
        )
        # get_act_fn also receives quant_config and intermediate_size so it
        # can build a quantization-aware activation if required.
        self.act = get_act_fn(config.hidden_act, quant_config,
                              config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``c_proj(act(c_fc(hidden_states)))``.

        The parallel linear layers return an (output, bias) pair; only the
        output is used here.
        """
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class Starcoder2DecoderLayer(nn.Module):
    """One StarCoder2 transformer layer with pre-LayerNorm residual blocks.

    ``x = x + attn(ln_1(x))`` followed by ``x = x + mlp(ln_2(x))``, where
    both LayerNorms use ``config.norm_epsilon``.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Starcoder2Attention(config,
                                             cache_config,
                                             quant_config=quant_config)
        self.mlp = Starcoder2MLP(config, quant_config=quant_config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.norm_epsilon)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP residual blocks on ``hidden_states``.

        Args:
            positions: Token positions, forwarded to the attention block.
            hidden_states: Layer input activations.
            kv_cache: This layer's KV-cache tensor, forwarded to attention.
            attn_metadata: Batch attention metadata, forwarded to attention.

        Returns:
            The layer output (same width as the input, ``hidden_size``).
        """
        # Self Attention (pre-norm, residual connection around it).
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states

        # Fully Connected (pre-norm, residual connection around it).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class Starcoder2Model(nn.Module):
    """StarCoder2 transformer trunk.

    The trunk is a token embedding, ``config.num_hidden_layers`` decoder
    layers and a final LayerNorm. It returns hidden states; projecting
    them to logits is left to ``Starcoder2ForCausalLM``.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # TODO: consider padding_idx (currently removed)
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.layers = nn.ModuleList([
            Starcoder2DecoderLayer(config,
                                   cache_config,
                                   quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions, forwarded to every layer.
            kv_caches: One KV-cache tensor per decoder layer, indexed by
                layer number. Must have at least ``len(self.layers)``
                entries.
            attn_metadata: Batch attention metadata shared by all layers.

        Returns:
            The final-LayerNorm-normalized hidden states.
        """
        hidden_states = self.embed_tokens(input_ids)
        for i, layer in enumerate(self.layers):
            hidden_states = layer(positions, hidden_states, kv_caches[i],
                                  attn_metadata)
        hidden_states = self.norm(hidden_states)
        return hidden_states


class Starcoder2ForCausalLM(nn.Module):
    """StarCoder2 causal language model for vLLM.

    Wraps ``Starcoder2Model`` and adds an LM head, a ``LogitsProcessor``
    and a ``Sampler``. When ``config.tie_word_embeddings`` is set, the LM
    head reuses the input embedding module instead of owning its own
    weights.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.model = Starcoder2Model(config,
                                     cache_config,
                                     quant_config=quant_config)
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size
        if config.tie_word_embeddings:
            # Share the embedding module; `load_weights` therefore skips
            # any separate `lm_head.weight` found in the checkpoint.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                quant_config=quant_config,
            )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the trunk's final hidden states (not logits).

        ``intermediate_tensors`` is accepted for interface compatibility
        but is not used by this implementation. Presumably it is intended
        for pipeline parallelism; confirm before relying on it.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` using the model's Sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Handling by weight name:

        * Names containing ``q_proj`` / ``k_proj`` / ``v_proj`` are
          renamed to ``qkv_proj``. They are loaded into the matching
          ("q", "k" or "v") shard of the fused parameter through its
          ``weight_loader``.
        * ``rotary_emb.inv_freq`` tensors are skipped.
        * ``lm_head.weight`` is skipped when embeddings are tied.
        * Every other tensor is loaded with the parameter's own
          ``weight_loader``, falling back to ``default_weight_loader``.

        Raises:
            KeyError: If a (possibly renamed) checkpoint name has no
                matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # remove_duplicate=False keeps aliased names (e.g. a tied
        # `lm_head.weight`) addressable as well.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked (q/k/v) mapping matched: load the tensor as-is.
                if self.config.tie_word_embeddings and "lm_head.weight" in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)