# starcoder2.py (14.4 KB)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
20
from typing import Iterable, List, Optional, Set, Tuple, Union
21
22
23

import torch
from torch import nn
24
from transformers import Starcoder2Config
25

26
from vllm.attention import Attention, AttentionMetadata
27
from vllm.compilation.decorators import support_torch_compile
28
from vllm.config import CacheConfig, VllmConfig
29
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
30
31
32
33
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
from vllm.model_executor.layers.quantization import QuantizationConfig
36
from vllm.model_executor.layers.rotary_embedding import get_rope
# (blame annotation from scraped view: Joe Runde committed)
37
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
39
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
41
from vllm.model_executor.sampling_metadata import SamplingMetadata
42
from vllm.sequence import IntermediateTensors
43

44
45
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
46
47
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
48

49
50
51
52
53

class Starcoder2Attention(nn.Module):
    """StarCoder2 self-attention with (possibly grouped) KV heads and RoPE.

    Query/key/value projections are fused into a single tensor-parallel
    ``QKVParallelLinear``; the attention output is projected back to
    ``hidden_size`` by a ``RowParallelLinear``.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: HF StarCoder2 config (head counts, sizes, rope_theta,
                use_bias, ...).
            cache_config: Passed through to ``Attention``.
            quant_config: Passed through to the linear layers and
                ``Attention``.
            prefix: Module name prefix; used to derive the names of the
                sub-modules (``qkv_proj``, ``o_proj``, ``attn``).
        """
        super().__init__()
        self.config = config

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        # Query heads must split evenly across tensor-parallel ranks.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # At least one KV head per rank (replicated when KV heads < TP size).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank widths of the q / k / v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings
        self.use_bias = config.use_bias

        # Pass a module prefix so quantization configs that select or skip
        # layers by name can identify these projections.
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=self.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=self.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # NOTE(review): int() truncates a fractional rope_theta; confirm this
        # matches the type get_rope expects for `base`.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            positions: Token positions fed to the rotary embedding.
            hidden_states: Input activations to project into q/k/v.
            kv_cache: This layer's KV cache, passed to ``self.attn``.
            attn_metadata: Backend metadata, passed to ``self.attn``.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into per-rank q/k/v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class Starcoder2MLP(nn.Module):
    """Feed-forward block: ``c_fc`` -> activation -> ``c_proj``."""

    def __init__(self,
                 config: Starcoder2Config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the two tensor-parallel linear layers and the activation.

        Args:
            config: HF StarCoder2 config (hidden_size, intermediate_size,
                use_bias, hidden_act).
            quant_config: Passed through to both linear layers.
            prefix: Module name prefix for the sub-layers. It defaults to
                "" so existing callers that omit it keep working.
        """
        super().__init__()
        # Prefixes let name-based quantization configs match these layers.
        self.c_fc = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``c_fc``, the activation, then ``c_proj``."""
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class Starcoder2DecoderLayer(nn.Module):
    """One pre-LayerNorm transformer block: attention then MLP, each residual."""

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build attention, MLP and the two LayerNorms.

        Args:
            config: HF StarCoder2 config.
            cache_config: Passed through to the attention module.
            quant_config: Passed through to the attention and MLP modules.
            prefix: Module name prefix for this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Starcoder2Attention(config,
                                             cache_config,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.self_attn")
        self.mlp = Starcoder2MLP(config,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.mlp")
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.norm_epsilon)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run ``x + attn(norm(x))`` followed by ``x + mlp(norm(x))``.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Input activations.
            kv_cache: This layer's KV cache.
            attn_metadata: Attention backend metadata.

        Returns:
            The layer output, with the same shape as the MLP/attention
            residual stream.
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


200
@support_torch_compile
class Starcoder2Model(nn.Module):
    """StarCoder2 decoder stack: token embedding, decoder layers, final norm.

    ``forward`` only runs the layers in ``[start_layer, end_layer)``.
    Non-first pipeline ranks read ``hidden_states`` from
    ``intermediate_tensors``. Non-last ranks return them wrapped in
    ``IntermediateTensors`` instead of applying the final norm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model from ``vllm_config``.

        Args:
            vllm_config: Supplies the HF config, cache config and
                quantization config.
            prefix: Module name prefix; layers are named under
                ``f"{prefix}.layers"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # TODO: consider padding_idx (currently removed)
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        # forward() iterates only over [start_layer, end_layer) of
        # self.layers, as returned by make_layers.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Starcoder2DecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's share of the decoder stack.

        Args:
            input_ids: Token ids. They are only embedded on the first
                pipeline rank, and only when ``inputs_embeds`` is None.
            positions: Token positions, forwarded to every layer.
            kv_caches: KV caches, indexed relative to ``start_layer``.
            attn_metadata: Attention backend metadata.
            intermediate_tensors: Required on non-first pipeline ranks;
                supplies ``"hidden_states"``.
            inputs_embeds: Optional precomputed embeddings (first rank
                only).

        Returns:
            Normalized hidden states on the last pipeline rank. On every
            other rank, an ``IntermediateTensors`` holding the un-normalized
            hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # kv_caches is indexed relative to this rank's first layer.
            hidden_states = layer(positions, hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states


260
class Starcoder2ForCausalLM(nn.Module, SupportsPP):
    """StarCoder2 causal LM: ``Starcoder2Model`` plus LM head, logits and sampler."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder, LM head, logits processor and sampler.

        Args:
            vllm_config: Supplies the HF config and quantization config.
            prefix: Module name prefix; the decoder is named
                ``maybe_prefix(prefix, "model")``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.model = Starcoder2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "model"))
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size
        if config.tie_word_embeddings:
            # Reuse the input embedding as the output projection; the
            # checkpoint's lm_head.weight is skipped in load_weights.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                quant_config=quant_config,
                # Name the head so name-based quant configs can match it.
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate embedding lookup to the decoder."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder and return its output unchanged.

        Logits are computed separately by ``compute_logits``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are written into
        the matching shard of the fused ``qkv_proj`` parameter. Three kinds
        of tensor are skipped:

        * rotary ``inv_freq`` buffers;
        * the tied ``lm_head.weight``;
        * parameters that ``is_pp_missing_parameter`` reports as missing on
          this pipeline rank.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names (after q/k/v -> qkv_proj renaming) of the parameters that
            were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skipping here leaves the inner loop without `break`, so the
                # `else` branch below runs and skips the same name again.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if self.config.tie_word_embeddings and "lm_head.weight" in name:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params