starcoder2.py 13.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
22
23
"""PyTorch Starcoder2 model."""

24
from collections.abc import Iterable
25
from itertools import islice
26
27
28

import torch
from torch import nn
29
from transformers import Starcoder2Config
30

31
from vllm.attention import Attention
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
35
from vllm.model_executor.layers.activation import get_act_fn
36
37
38
39
40
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.quantization import QuantizationConfig
43
from vllm.model_executor.layers.rotary_embedding import get_rope
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
45
46
47
    ParallelLMHead,
    VocabParallelEmbedding,
)
48
from vllm.model_executor.model_loader.weight_utils import (
49
50
51
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
52
from vllm.sequence import IntermediateTensors
53

54
from .interfaces import SupportsPP
55
56
57
58
59
60
61
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
62

63
64

class Starcoder2Attention(nn.Module):
    """Self-attention sub-block of a Starcoder2 decoder layer.

    Q, K and V come from one fused ``QKVParallelLinear``. Rotary position
    embeddings are applied to Q and K, the vLLM ``Attention`` op runs on the
    result, and a ``RowParallelLinear`` projects back to ``hidden_size``.
    Query heads are sharded across tensor-parallel ranks; KV heads are
    partitioned when there are at least as many as ranks, otherwise
    replicated.
    """

    def __init__(
        self,
        config: Starcoder2Config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the projections, rotary embedding and attention op.

        Args:
            config: Hugging Face Starcoder2 configuration.
            cache_config: Cache configuration forwarded to ``Attention``.
            quant_config: Optional quantization config forwarded to the
                linear layers and the attention op.
            prefix: Module-name prefix used to name the sub-layers.
        """
        super().__init__()
        self.config = config

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        # Query heads must split evenly across tensor-parallel ranks.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # In the replication case each rank still holds one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank widths of the Q slice and each K/V slice of the fused
        # QKV projection output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings
        self.use_bias = config.use_bias

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=self.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=self.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # rope_theta is not guaranteed to be an integer. Truncating it with
        # int() would silently change the rotary base and therefore every
        # rotary frequency, so it is passed through as a float.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=float(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            positions: Token positions consumed by the rotary embedding.
            hidden_states: Input activations for this rank.

        Returns:
            The attention output after the output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into this rank's Q, K and V parts.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        # The linear layer returns a pair; the second element is unused here.
        output, _ = self.o_proj(attn_output)
        return output


class Starcoder2MLP(nn.Module):
    """Feed-forward sub-block of a Starcoder2 decoder layer.

    ``c_fc`` (column-parallel) maps ``hidden_size`` to ``intermediate_size``.
    The activation named by ``config.hidden_act`` is applied. ``c_proj``
    (row-parallel) maps back to ``hidden_size``.
    """

    def __init__(
        self,
        config: Starcoder2Config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the two projections and the activation.

        Args:
            config: Hugging Face Starcoder2 configuration.
            quant_config: Optional quantization config for both projections.
            prefix: Module-name prefix used to name the sub-layers.
        """
        super().__init__()
        self.c_fc = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation, then down-projection.

        Args:
            hidden_states: Input activations for this rank.

        Returns:
            The MLP output with the same hidden width as the input.
        """
        # Each parallel linear returns a pair; only the first element is used.
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class Starcoder2DecoderLayer(nn.Module):
    """One Starcoder2 transformer layer.

    Pre-norm layout: ``LayerNorm -> self-attention -> residual add``,
    followed by ``LayerNorm -> MLP -> residual add``.
    """

    def __init__(
        self,
        config: Starcoder2Config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the attention, MLP and the two LayerNorms.

        Args:
            config: Hugging Face Starcoder2 configuration.
            cache_config: Cache configuration forwarded to the attention.
            quant_config: Optional quantization config for the sub-blocks.
            prefix: Module-name prefix used to name the sub-layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Starcoder2Attention(
            config,
            cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = Starcoder2MLP(
            config, quant_config=quant_config, prefix=f"{prefix}.mlp"
        )
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.norm_epsilon
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and MLP, each wrapped in a pre-norm residual.

        Args:
            positions: Token positions forwarded to the attention.
            hidden_states: Input activations for this rank.

        Returns:
            The layer output, same shape as ``hidden_states``.
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


222
@support_torch_compile
class Starcoder2Model(nn.Module):
    """Starcoder2 decoder stack: embedding, decoder layers and final norm.

    Supports pipeline parallelism. Each rank runs only the layers in
    ``[start_layer, end_layer)``. Non-last ranks hand their activations on
    as ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model from ``vllm_config``.

        Args:
            vllm_config: Engine configuration; the HF config, cache config
                and quantization config are read from it.
            prefix: Module-name prefix for this model's sub-layers.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        # maybe_prefix keeps sub-layer names free of a leading "." when this
        # model is constructed with an empty prefix.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        # make_layers also returns the [start_layer, end_layer) range this
        # rank executes in forward().
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Starcoder2DecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=maybe_prefix(prefix, "layers"),
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's share of the decoder stack.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank unless
                ``inputs_embeds`` is given.
            positions: Token positions forwarded to every layer.
            intermediate_tensors: Activations from the previous pipeline
                rank; required on every rank except the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            ``IntermediateTensors`` on non-last ranks, otherwise the final
            normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are written into
        the matching shard of the fused ``qkv_proj`` parameter. Every other
        tensor is loaded by name, after ``maybe_remap_kv_scale_name``; tensors
        it maps to ``None`` are dropped. Tensors for which
        ``is_pp_missing_parameter`` is true are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Resolve the target parameter name exactly once. The scan must
            # stop at the first matching shard: the rewritten name contains
            # "qkv_proj", which itself contains "v_proj", so continuing would
            # rewrite it again into a bogus name such as "qkqkv_proj".
            shard_id = None
            for param_name, weight_name, stacked_shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break
            else:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            # Skip tensors for layers this pipeline-parallel rank does not hold.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused parameters must expose a shard-aware weight_loader.
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

310

311
class Starcoder2ForCausalLM(nn.Module, SupportsPP):
    """Starcoder2 decoder topped with a language-modeling head.

    Wraps ``Starcoder2Model``. The LM head is either the input embedding
    (when ``tie_word_embeddings`` is set) or a separate ``ParallelLMHead``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder, LM head and logits processor.

        Args:
            vllm_config: Engine configuration; the HF config and
                quantization config are read from it.
            prefix: Module-name prefix for this model's sub-layers.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.model = Starcoder2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.vocab_size = config.vocab_size

        if config.tie_word_embeddings:
            # Reuse the input embedding module as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                # Same naming scheme as self.model: with an empty prefix the
                # layer is named "lm_head", not ".lm_head".
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the wrapped model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack; see ``Starcoder2Model.forward``.

        Returns:
            Hidden states (last pipeline rank) or ``IntermediateTensors``.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names of the parameters the loader reports as loaded.
        """
        loader = AutoWeightsLoader(
            self,
            # With tied embeddings lm_head *is* model.embed_tokens, so a
            # separate "lm_head.weight" tensor in the checkpoint is skipped
            # instead of being loaded into the shared module.
            skip_prefixes=(
                ["lm_head.weight"] if self.config.tie_word_embeddings else None
            ),
        )
        return loader.load_weights(weights)