# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""

from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union

import torch
from torch import nn
from transformers import GPTBigCodeConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class GPTBigCodeAttention(nn.Module):
    """Tensor-parallel self-attention for GPTBigCode.

    Q, K and V come from one fused projection (``c_attn``). The attention
    output is projected back to ``hidden_size`` by ``c_proj``.

    When ``config.multi_query`` is set, a single KV head is shared by all
    query heads. Otherwise every query head has its own KV head.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % self.tensor_model_parallel_world_size == 0, (
            f"num_attention_heads ({total_num_heads}) must be divisible by "
            f"the tensor parallel size "
            f"({self.tensor_model_parallel_world_size})")
        # head_dim is derived by integer division. A remainder would make
        # the fused QKV output narrower than the split below expects.
        # Reject such configs up front instead of failing later in split().
        if self.hidden_size % total_num_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({total_num_heads})")
        # Number of query heads held by this tensor-parallel rank.
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            # Multi-query attention: one KV head in total, and every rank
            # works with exactly one KV head regardless of TP size.
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        # Per-rank widths of the Q and K/V slices in the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_dim = self.head_dim * self.num_kv_heads
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected
        output."""
        # The parallel linear layers return a pair; only the first element
        # (the output tensor) is used.
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_dim, self.kv_dim], dim=-1)
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer of a GPTBigCode block.

    ``c_fc`` (column parallel) widens ``hidden_size`` to
    ``intermediate_size``. Then ``config.activation_function`` is applied.
    Finally ``c_proj`` (row parallel) narrows back to ``hidden_size``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.c_fc = ColumnParallelLinear(
            config.hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Each parallel linear returns a pair; only the first element (the
        # output tensor) is kept.
        widened, _ = self.c_fc(hidden_states)
        narrowed, _ = self.c_proj(self.act(widened))
        return narrowed


class GPTBigCodeBlock(nn.Module):
    """A single pre-LayerNorm GPTBigCode layer.

    The layer computes ``x + attn(ln_1(x))``, then adds
    ``mlp(ln_2(.))`` on top as a second residual branch.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_epsilon
        # The MLP width falls back to 4x the hidden size when n_inner is
        # unset.
        mlp_width = config.n_inner
        if mlp_width is None:
            mlp_width = 4 * width

        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.attn = GPTBigCodeAttention(config,
                                        cache_config,
                                        quant_config,
                                        prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        self.mlp = GPTBigMLP(mlp_width,
                             config,
                             quant_config,
                             prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # Attention sub-layer: normalize, attend, add the residual back.
        attn_output = self.attn(hidden_states=self.ln_1(hidden_states))
        hidden_states = attn_output + hidden_states
        # Feed-forward sub-layer: normalize, MLP, add the residual back.
        mlp_output = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_output


@support_torch_compile
class GPTBigCodeModel(nn.Module):
    """GPTBigCode decoder stack without the LM head.

    The model is built from:
    - token embeddings (``wte``) plus learned position embeddings (``wpe``);
    - the slice of transformer blocks owned by this pipeline stage;
    - a final LayerNorm, applied only on the last pipeline rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # Cross-attention is not supported by this implementation.
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size
        # Extra embedding rows reserved for LoRA-added tokens. The count is
        # lora_extra_vocab_size scaled by max_loras (1 when unset).
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.wte = VocabParallelEmbedding(self.vocab_size,
                                          self.embed_dim,
                                          org_num_embeddings=config.vocab_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        # Only layers [start_layer, end_layer) belong to this pipeline stage.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTBigCodeBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        # Hidden states passed between pipeline stages have the same width
        # as the rest of the model (embed_dim).
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    self.embed_dim))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's layers.

        Returns:
            The final-LayerNorm hidden states on the last pipeline rank.
            On every other rank, an ``IntermediateTensors`` holding
            ``hidden_states`` for the next stage.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            hidden_states = inputs_embeds + self.wpe(position_ids)
        else:
            # Later stages resume from the previous stage's output. Fail
            # clearly instead of with a TypeError when it is missing.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = layer(hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this module.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if is_pp_missing_parameter(name, self):
                # The parameter lives on a different pipeline stage.
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            if "c_attn.input_scale" in name:
                # One checkpoint scale is loaded into each fused q/k/v shard.
                for shard_id in ("q", "k", "v"):
                    weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GPTBigCode decoder topped with a language-modeling head."""

    packed_modules_mapping = {"c_attn": ["c_attn"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config
        self.quant_config = vllm_config.quant_config
        self.transformer = GPTBigCodeModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "transformer"))

        if not config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                self.transformer.vocab_size,
                self.transformer.embed_dim,
                org_num_embeddings=config.vocab_size,
                prefix=maybe_prefix(prefix, "lm_head"))
        else:
            # Tied weights: the token embedding doubles as the output head.
            self.lm_head = self.transformer.wte

        extra_vocab = lora_config.lora_extra_vocab_size if lora_config else 0
        self.unpadded_vocab_size = config.vocab_size + extra_vocab
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the token-embedding lookup to the decoder."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder. Logits are produced separately by
        ``compute_logits``."""
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights.

        With tied embeddings, any ``lm_head.*`` tensors in the checkpoint
        are ignored.
        """
        skip_prefixes = (["lm_head."]
                         if self.config.tie_word_embeddings else None)
        return AutoWeightsLoader(
            self, skip_prefixes=skip_prefixes).load_weights(weights)