# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
22
from typing import Iterable, Optional, Set, Tuple, Union
23
24
25
26
27

import torch
from torch import nn
from transformers import GPTBigCodeConfig

28
from vllm.attention import Attention
29
from vllm.compilation.decorators import support_torch_compile
30
from vllm.config import CacheConfig, VllmConfig
31
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
32
from vllm.model_executor.layers.activation import get_act_fn
33
34
35
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
36
from vllm.model_executor.layers.logits_processor import LogitsProcessor
37
from vllm.model_executor.layers.quantization import QuantizationConfig
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
39
    ParallelLMHead, VocabParallelEmbedding)
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
41
from vllm.model_executor.sampling_metadata import SamplingMetadata
42
from vllm.sequence import IntermediateTensors
43

44
from .interfaces import SupportsLoRA, SupportsPP
45
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
46
                    make_empty_intermediate_tensors_factory, make_layers)
47

48
49
50

class GPTBigCodeAttention(nn.Module):
    """Self-attention layer of a GPTBigCode transformer block.

    Hidden states are projected to query/key/value by a fused
    ``QKVParallelLinear`` (``c_attn``), passed through the vLLM ``Attention``
    op, and projected back by a ``RowParallelLinear`` (``c_proj``).

    When ``config.multi_query`` is true a single key/value head is used;
    otherwise the number of key/value heads equals the number of query heads
    held by this tensor-parallel rank.

    Args:
        config: HuggingFace GPTBigCode configuration.
        cache_config: Forwarded to the ``Attention`` op.
        quant_config: Forwarded to the linear layers and the ``Attention`` op.
        prefix: Dotted name of this module; sub-layers are named
            ``{prefix}.c_attn``, ``{prefix}.c_proj`` and ``{prefix}.attn``.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Query heads are divided evenly between tensor-parallel ranks.
        assert total_num_heads % self.tensor_model_parallel_world_size == 0
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        # Width of the key (and of the value) slice produced on this rank.
        self.kv_dim = self.head_dim * self.num_kv_heads

        # The linear sub-layers receive their dotted names, just like the
        # Attention op below, so that anything keyed on the layer name
        # (e.g. a quantization config) can tell them apart.
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states`` and return the result.

        The fused projection output is split along the last dimension into a
        query slice of width ``hidden_size // tensor_parallel_world_size``
        followed by key and value slices of width ``kv_dim`` each.
        """
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.split(
            [
                self.hidden_size // self.tensor_model_parallel_world_size,
                self.kv_dim, self.kv_dim
            ],
            dim=-1,
        )
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer of a GPTBigCode transformer block.

    Computes ``c_proj(act(c_fc(x)))`` where ``c_fc`` maps ``hidden_size`` to
    ``intermediate_size`` and ``c_proj`` maps it back.

    Args:
        intermediate_size: Width of the inner projection.
        config: HuggingFace GPTBigCode configuration; supplies
            ``hidden_size`` and ``activation_function``.
        quant_config: Forwarded to both linear layers.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        model_dim = config.hidden_size
        # Both projections share the same bias / quantization settings.
        shared_kwargs = dict(bias=True, quant_config=quant_config)
        self.c_fc = ColumnParallelLinear(model_dim, intermediate_size,
                                         **shared_kwargs)
        self.c_proj = RowParallelLinear(intermediate_size, model_dim,
                                        **shared_kwargs)
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the feed-forward transformation of ``hidden_states``."""
        # Each linear layer returns a pair; only the first element is used.
        expanded, _ = self.c_fc(hidden_states)
        projected, _ = self.c_proj(self.act(expanded))
        return projected


class GPTBigCodeBlock(nn.Module):
    """One GPTBigCode transformer block.

    Applies layer norm followed by self-attention, then layer norm followed
    by the MLP, adding a residual connection around each of the two
    sub-layers.

    Args:
        config: HuggingFace GPTBigCode configuration.
        cache_config: Forwarded to the attention sub-layer.
        quant_config: Forwarded to the attention and MLP sub-layers.
        prefix: Dotted name of this block; the attention sub-layer is named
            ``{prefix}.attn``.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_epsilon
        # The MLP inner width falls back to 4x the hidden size when the
        # config leaves ``n_inner`` unset.
        inner_dim = config.n_inner
        if inner_dim is None:
            inner_dim = 4 * width

        self.ln_1 = nn.LayerNorm(width, eps=norm_eps)
        self.attn = GPTBigCodeAttention(config,
                                        cache_config,
                                        quant_config,
                                        prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(width, eps=norm_eps)
        self.mlp = GPTBigMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers with residual connections."""
        # Attention sub-layer plus residual.
        attn_out = self.attn(hidden_states=self.ln_1(hidden_states), )
        after_attn = attn_out + hidden_states

        # MLP sub-layer plus residual.
        mlp_out = self.mlp(self.ln_2(after_attn))
        return after_attn + mlp_out


188
@support_torch_compile
class GPTBigCodeModel(nn.Module):
    """GPTBigCode transformer stack without the language-model head.

    Holds the token embedding (``wte``), the position embedding (``wpe``),
    the transformer blocks (``h``) and the final layer norm (``ln_f``).
    ``start_layer`` and ``end_layer`` come from ``make_layers`` and bound the
    slice of ``h`` that ``forward`` executes.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size

        # With LoRA enabled the vocabulary grows by the extra LoRA tokens,
        # once per LoRA slot (at least one slot).
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = config.vocab_size + extra_vocab

        self.wte = VocabParallelEmbedding(self.vocab_size,
                                          self.embed_dim,
                                          org_num_embeddings=config.vocab_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        def build_block(prefix: str) -> GPTBigCodeBlock:
            # Layer factory handed to ``make_layers``; the parameter keeps
            # the name ``prefix`` in case it is supplied by keyword.
            return GPTBigCodeBlock(config,
                                   cache_config,
                                   quant_config,
                                   prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's slice of the transformer.

        On the first pipeline rank the input is the sum of token embeddings
        (computed from ``input_ids`` unless ``inputs_embeds`` is given) and
        position embeddings. On other ranks it is read from
        ``intermediate_tensors["hidden_states"]``.

        Returns:
            ``IntermediateTensors`` carrying ``hidden_states`` on every rank
            but the last; on the last rank, the hidden states after the
            final layer norm.
        """
        if not get_pp_group().is_first_rank:
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            hidden_states = inputs_embeds + self.wpe(position_ids)

        local_blocks = self.h[self.start_layer:self.end_layer]
        for block in local_blocks:
            hidden_states = block(hidden_states)

        if get_pp_group().is_last_rank:
            return self.ln_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(parameter name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        named_params = dict(self.named_parameters(remove_duplicate=False))
        # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
        fused_scale_keys = ("c_attn.input_scale", "c_attn.weight_scale")
        done: Set[str] = set()

        for name, tensor in weights:
            # ".attn.bias" is the attention mask and is skipped.
            # NOTE: "c_attn.bias" does not contain that substring and is
            # therefore still loaded.
            if ".attn.bias" in name or is_pp_missing_parameter(name, self):
                continue

            target = named_params[name]
            load_fn = getattr(target, "weight_loader", default_weight_loader)
            if any(key in name for key in fused_scale_keys):
                # A scale of the fused c_attn layer is loaded once for each
                # of the q, k and v shards.
                for shard_id in ('q', 'k', 'v'):
                    load_fn(target, tensor, shard_id)
            else:
                load_fn(target, tensor)
            done.add(name)
        return done

270

271
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GPTBigCode transformer with a language-model head.

    ``lm_head`` is the token embedding ``transformer.wte`` itself when
    ``config.tie_word_embeddings`` is true, and a separate
    ``ParallelLMHead`` otherwise.
    """

    packed_modules_mapping = {"c_attn": ["c_attn"]}

    # LoRA specific attributes
    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
                                           prefix=prefix)
        if self.config.tie_word_embeddings:
            # Tied embeddings: the output head shares the input embedding.
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(
                self.transformer.vocab_size,
                self.transformer.embed_dim,
                org_num_embeddings=self.config.vocab_size)
        # The logits processor is given the base vocabulary plus, with LoRA,
        # the extra LoRA tokens.
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the transformer and return whatever it returns.

        Logits are not computed here; see ``compute_logits``.
        """
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn ``hidden_states`` into logits using ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        ``lm_head.*`` entries are skipped only when the embeddings are tied,
        because ``lm_head`` is then the same module as ``transformer.wte``.
        With untied embeddings ``lm_head`` is a separate ``ParallelLMHead``
        and its weights must come from the checkpoint.

        Args:
            weights: Iterable of ``(parameter name, tensor)`` pairs.

        Returns:
            The set of parameter names reported as loaded by the loader.
        """
        skip_prefixes = None
        if self.config.tie_word_embeddings:
            skip_prefixes = ["lm_head."]
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=skip_prefixes,
        )
        return loader.load_weights(weights)