# SPDX-License-Identifier: Apache-2.0

# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
from typing import Iterable, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import Starcoder2Config

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class Starcoder2Attention(nn.Module):
    """Self-attention block of a Starcoder2 decoder layer.

    Query, key and value are produced by one ``QKVParallelLinear`` layer,
    split per tensor-parallel rank, passed through the module returned by
    ``get_rope`` and then through ``Attention`` before the output
    projection.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the projections, rotary embedding and attention op.

        Args:
            config: Model configuration; ``hidden_size``,
                ``num_attention_heads``, ``num_key_value_heads``,
                ``rope_theta``, ``max_position_embeddings`` and
                ``use_bias`` are read here.
            cache_config: Forwarded unchanged to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Dotted name of this module; sub-layer prefixes are
                derived from it.

        Raises:
            AssertionError: If the head counts are incompatible with the
                tensor-parallel world size (see the checks below).
        """
        super().__init__()
        self.config = config

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # At least one KV head per rank, even when heads are replicated.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank widths used by forward() to split the QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings
        self.use_bias = config.use_bias

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=self.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=self.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states``.

        Args:
            positions: Passed to the rotary embedding together with q and k.
            hidden_states: Input to the QKV projection.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        # The linear layers return a tuple; only the first item is used.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Starcoder2MLP(nn.Module):
    """Feed-forward block: ``c_fc`` -> activation -> ``c_proj``."""

    def __init__(self,
                 config: Starcoder2Config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the two projections and the activation.

        Args:
            config: Model configuration; ``hidden_size``,
                ``intermediate_size``, ``use_bias`` and ``hidden_act``
                are read here.
            quant_config: Forwarded to both linear layers.
            prefix: Dotted name of this module; sub-layer prefixes are
                derived from it.
        """
        super().__init__()
        self.c_fc = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        # Activation is looked up by the name stored in the config.
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run ``hidden_states`` through ``c_fc``, the activation, ``c_proj``.

        Returns:
            The first element of the tuple returned by ``c_proj``.
        """
        # The linear layers return a tuple; only the first item is used.
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class Starcoder2DecoderLayer(nn.Module):
    """One decoder layer: attention and MLP, each behind a LayerNorm.

    Both sub-blocks normalise their input first and add the un-normalised
    input back to their output (see ``forward``).
    """

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the attention block, the MLP and the two LayerNorms.

        Args:
            config: Model configuration; ``hidden_size`` and
                ``norm_epsilon`` are read here, and the object is passed
                on to the sub-blocks.
            cache_config: Forwarded to ``Starcoder2Attention``.
            quant_config: Forwarded to both sub-blocks.
            prefix: Dotted name of this layer; sub-module prefixes are
                derived from it.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Starcoder2Attention(config,
                                             cache_config,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.self_attn")
        self.mlp = Starcoder2MLP(config,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.mlp")
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.norm_epsilon)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks with residual adds.

        Args:
            positions: Forwarded to the attention block.
            hidden_states: Layer input.

        Returns:
            The hidden states after both residual additions.
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


@support_torch_compile
class Starcoder2Model(nn.Module):
    """Token embedding, the stack of decoder layers and the final norm.

    ``forward`` branches on the pipeline-parallel rank: the first rank
    starts from embeddings, later ranks start from received hidden
    states, and only the last rank applies the final LayerNorm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, decoder layers and final LayerNorm.

        Args:
            vllm_config: Source of the HF model config, the cache config
                and the quantization config.
            prefix: Dotted name of this module; sub-module prefixes are
                derived from it.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # TODO: consider padding_idx (currently removed)
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens")
        # make_layers calls the factory with each layer's own prefix and
        # also returns the [start_layer, end_layer) range forward() uses.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Starcoder2DecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return ``embed_tokens`` applied to ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank
                when ``inputs_embeds`` is not given.
            positions: Forwarded to every decoder layer.
            intermediate_tensors: Must not be ``None`` on ranks other
                than the first; its ``"hidden_states"`` entry is the
                starting point there.
            inputs_embeds: Optional embeddings used instead of looking
                up ``input_ids`` on the first rank.

        Returns:
            ``IntermediateTensors`` holding ``"hidden_states"`` when this
            is not the last pipeline rank, otherwise the hidden states
            after the final LayerNorm.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = layer(positions, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states


class Starcoder2ForCausalLM(nn.Module, SupportsPP):
    """Starcoder2 decoder with an LM head, logits processor and sampler."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder, the LM head and the sampling helpers.

        Args:
            vllm_config: Source of the HF model config and the
                quantization config; also passed to ``Starcoder2Model``.
            prefix: Dotted name of this module; sub-module prefixes are
                derived from it with ``maybe_prefix``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.model = Starcoder2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "model"))
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size
        if config.tie_word_embeddings:
            # Tied weights: reuse the input embedding module as the head.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                quant_config=quant_config,
                # Built the same way as the `model` prefix above, so an
                # empty `prefix` does not produce a leading-dot name.
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to ``Starcoder2Model.get_input_embeddings``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder and return its result unchanged.

        The LM head is not applied here; see ``compute_logits``.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the logits processor's output for ``hidden_states``.

        The processor is called with ``lm_head``, the hidden states and
        ``sampling_metadata``.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Return the sampler's output for ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The (possibly remapped) parameter names that reached the end
            of a loop iteration without being skipped.

        Raises:
            KeyError: If a non-skipped name has no matching parameter.
        """
        # Separate q/k/v checkpoint tensors are loaded as shards of the
        # single qkv_proj parameter.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # remove_duplicate=False keeps every name of a shared parameter
        # (e.g. lm_head when it is tied to embed_tokens).
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE(review): this `continue` only advances the inner
                # loop, with `name` already rewritten; the weight is
                # finally skipped by the pipeline-parallel check in the
                # `else` branch below.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not loaded as a q/k/v shard: load as a regular parameter.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                # With tied embeddings lm_head shares embed_tokens, so a
                # separate lm_head tensor in the checkpoint is ignored.
                if self.config.tie_word_embeddings and "lm_head.weight" in name:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params