# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Starcoder2 model."""

from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import Starcoder2Config

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class Starcoder2Attention(nn.Module):
    """Tensor-parallel self-attention block used by every Starcoder2 layer.

    Q, K and V come out of a single fused ``QKVParallelLinear`` projection and
    are split apart. Rotary position embeddings are applied to Q and K. The
    attention op runs over this rank's heads, and a ``RowParallelLinear``
    projects the result back to ``hidden_size``.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.config = config

        hidden = config.hidden_size
        world = get_tensor_model_parallel_world_size()
        n_heads = config.num_attention_heads
        # Query heads must split evenly over the tensor-parallel ranks.
        assert n_heads % world == 0

        n_kv_heads = config.num_key_value_heads
        if n_kv_heads < world:
            # Fewer KV heads than ranks: each KV head is replicated, so the
            # rank count must be a multiple of the KV head count.
            assert world % n_kv_heads == 0
        else:
            # At least one KV head per rank: KV heads are partitioned evenly.
            assert n_kv_heads % world == 0

        head_dim = hidden // n_heads
        local_heads = n_heads // world
        local_kv_heads = max(1, n_kv_heads // world)

        self.hidden_size = hidden
        self.total_num_heads = n_heads
        self.num_heads = local_heads
        self.total_num_kv_heads = n_kv_heads
        self.num_kv_heads = local_kv_heads
        self.head_dim = head_dim
        # Widths of this rank's Q and K/V slices inside the fused projection.
        self.q_size = local_heads * head_dim
        self.kv_size = local_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.rope_theta = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings
        self.use_bias = config.use_bias

        self.qkv_proj = QKVParallelLinear(
            hidden,
            head_dim,
            n_heads,
            n_kv_heads,
            bias=self.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            n_heads * head_dim,
            hidden,
            bias=self.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # NOTE: rope_theta is truncated to an int before being handed to
        # get_rope.
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(local_heads,
                              head_dim,
                              self.scaling,
                              num_kv_heads=local_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` using rotary ``positions``."""
        fused, _ = self.qkv_proj(hidden_states)
        widths = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(widths, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected


class Starcoder2MLP(nn.Module):
    """Feed-forward sub-block of a Starcoder2 layer.

    Column-parallel up-projection (``c_fc``), then the configured activation,
    then row-parallel down-projection (``c_proj``) back to ``hidden_size``.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size
        with_bias = config.use_bias

        self.c_fc = ColumnParallelLinear(
            hidden,
            inner,
            bias=with_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            inner,
            hidden,
            bias=with_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation and down-projection."""
        expanded, _ = self.c_fc(hidden_states)
        reduced, _ = self.c_proj(self.act(expanded))
        return reduced


class Starcoder2DecoderLayer(nn.Module):
    """One Starcoder2 transformer block in pre-LayerNorm form.

    Computes ``x = x + attn(ln_in(x))`` and then ``x = x + mlp(ln_post(x))``.
    """

    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Starcoder2Attention(config,
                                             cache_config,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.self_attn")
        self.mlp = Starcoder2MLP(config,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.mlp")

        width = config.hidden_size
        eps = config.norm_epsilon
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-blocks, each with a residual add."""
        # Attention sub-block: normalise, attend, add back the input.
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block: normalise, MLP, add back the residual.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out


205
@support_torch_compile
class Starcoder2Model(nn.Module):
    """Starcoder2 decoder stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism, only the first rank embeds tokens (or takes
    ``inputs_embeds``). Only the last rank applies the final LayerNorm. Other
    ranks receive and return ``hidden_states`` via ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        # maybe_prefix is used (as for "model" in Starcoder2ForCausalLM) so an
        # empty prefix does not produce names with a leading ".".
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "embed_tokens"))
        # make_layers also returns the [start_layer, end_layer) slice of
        # layers that forward() runs on this pipeline rank.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Starcoder2DecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=maybe_prefix(prefix, "layers"),
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the vocab-parallel embedding table."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline rank's slice of the decoder.

        Returns the final-normed hidden states on the last PP rank, otherwise
        an ``IntermediateTensors`` carrying ``hidden_states`` to the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = layer(positions, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this (possibly PP-partitioned) model.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are written into
        the matching shard of the fused ``qkv_proj`` parameter. Every other
        tensor is loaded with its parameter's ``weight_loader``, falling back
        to ``default_weight_loader``. Tensors belonging to layers not present
        on this pipeline rank are skipped.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "scale" in name:
                # Remap KV-cache scale names (e.g. "...self_attn.k_proj.k_scale")
                # BEFORE the stacked-shard lookup. Otherwise "k_proj" would be
                # rewritten to "qkv_proj" and index a parameter that does not
                # exist. Unrecognised names come back unchanged; None means
                # there is no destination parameter, so the tensor is skipped.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            # Pick the first matching shard exactly once. The rewritten name
            # must not be re-scanned, because "v_proj" is also a substring of
            # "qkv_proj".
            shard = next(
                ((param_name, weight_name, shard_id)
                 for param_name, weight_name, shard_id in stacked_params_mapping
                 if weight_name in name),
                None,
            )
            if shard is not None:
                param_name, weight_name, shard_id = shard
                name = name.replace(weight_name, param_name)

            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            if shard is not None:
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

295

296
class Starcoder2ForCausalLM(nn.Module, SupportsPP):
    """Starcoder2 decoder stack plus a language-modelling head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.model = Starcoder2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "model"))
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size
        if config.tie_word_embeddings:
            # The output projection shares the input embedding module.
            self.lm_head = self.model.embed_tokens
        else:
            # Use maybe_prefix, as for "model" above. The previous
            # f"{prefix}.lm_head" produced ".lm_head" for the default empty
            # prefix. Presumably quantization configs match layer names
            # against this prefix.
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` using the underlying model's embedding table."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack. Logits are produced by compute_logits."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors and return the names that were loaded."""
        loader = AutoWeightsLoader(
            self,
            # With tied embeddings, lm_head is model.embed_tokens, so a
            # separate "lm_head.weight" tensor in the checkpoint is redundant.
            # Skip it.
            skip_prefixes=(["lm_head.weight"]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)