gpt_bigcode.py 12.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
22
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
23
24
from collections.abc import Iterable
from typing import Optional, Union
25
26
27
28
29

import torch
from torch import nn
from transformers import GPTBigCodeConfig

30
from vllm.attention import Attention
31
from vllm.compilation.decorators import support_torch_compile
32
from vllm.config import CacheConfig, VllmConfig
33
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
34
from vllm.model_executor.layers.activation import get_act_fn
35
36
37
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
38
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
from vllm.model_executor.layers.quantization import QuantizationConfig
40
from vllm.model_executor.layers.vocab_parallel_embedding import (
41
    ParallelLMHead, VocabParallelEmbedding)
42
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
43
from vllm.model_executor.sampling_metadata import SamplingMetadata
44
from vllm.sequence import IntermediateTensors
45

46
from .interfaces import SupportsLoRA, SupportsPP
47
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
48
                    make_empty_intermediate_tensors_factory, make_layers)
49

50
51
52

class GPTBigCodeAttention(nn.Module):
    """Self-attention sub-layer of a GPTBigCode block.

    Consists of a fused QKV projection (``c_attn``), the vLLM ``Attention``
    layer (``attn``) and an output projection (``c_proj``). When
    ``config.multi_query`` is set, a single key/value head is used;
    otherwise there is one key/value head per query head.

    Args:
        config: HuggingFace GPTBigCode configuration.
        cache_config: Forwarded to the ``Attention`` layer.
        quant_config: Forwarded to the linear layers and ``Attention``.
        prefix: Dotted name of this module in the model; used to derive
            the names handed to the child layers.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Query heads are divided evenly across tensor-parallel ranks.
        assert total_num_heads % self.tensor_model_parallel_world_size == 0
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        # Width of the key (and of the value) slice of the fused QKV output
        # on this rank; used by forward() to split the projection.
        self.kv_dim = self.head_dim * self.num_kv_heads
        # The linear layers receive their own dotted name so that
        # quantization configs which select behaviour by layer name can
        # identify them, consistent with the Attention layer below.
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply QKV projection, attention and the output projection.

        Args:
            hidden_states: Input activations for this layer.

        Returns:
            The output of ``c_proj`` applied to the attention result.
        """
        # The linear layers return a tuple; only the first item is used.
        qkv, _ = self.c_attn(hidden_states)
        # Split the fused projection along the last dimension: the query
        # takes this rank's share of hidden_size, key and value take
        # kv_dim each.
        q, k, v = qkv.split(
            [
                self.hidden_size // self.tensor_model_parallel_world_size,
                self.kv_dim, self.kv_dim
            ],
            dim=-1,
        )
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer of a GPTBigCode block.

    Chains a column-parallel projection (``c_fc``), the activation named by
    ``config.activation_function`` and a row-parallel projection
    (``c_proj``).

    Args:
        intermediate_size: Output width of ``c_fc`` / input width of
            ``c_proj``.
        config: HuggingFace GPTBigCode configuration.
        quant_config: Forwarded to both linear layers.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        model_dim = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            model_dim,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            model_dim,
            bias=True,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``c_proj(act(c_fc(hidden_states)))``."""
        # Each linear layer returns a pair; the second item is discarded.
        projected, _ = self.c_fc(hidden_states)
        projected = self.act(projected)
        projected, _ = self.c_proj(projected)
        return projected


class GPTBigCodeBlock(nn.Module):
    """One GPTBigCode transformer layer.

    Applies layer norm followed by attention, then layer norm followed by
    the MLP, each wrapped in a residual connection.

    Args:
        config: HuggingFace GPTBigCode configuration.
        cache_config: Forwarded to the attention sub-layer.
        quant_config: Forwarded to the attention and MLP sub-layers.
        prefix: Dotted name of this block; the attention sub-layer is
            named ``f"{prefix}.attn"``.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # MLP width falls back to 4x the hidden size when n_inner is unset.
        if config.n_inner is None:
            inner_dim = 4 * hidden_size
        else:
            inner_dim = config.n_inner
        norm_eps = config.layer_norm_epsilon

        self.ln_1 = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.attn = GPTBigCodeAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.attn",
        )
        self.ln_2 = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.mlp = GPTBigMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers with residual additions."""
        # Attention sub-layer: normalise, attend, add the input back.
        attn_out = self.attn(hidden_states=self.ln_1(hidden_states))
        hidden_states = attn_out + hidden_states

        # MLP sub-layer: normalise, project, add the input back.
        mlp_out = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_out


190
@support_torch_compile
class GPTBigCodeModel(nn.Module):
    """GPTBigCode transformer stack without the language-model head.

    Holds the token embedding (``wte``), the position embedding (``wpe``),
    the layers built by ``make_layers`` (``h``) and the final layer norm
    (``ln_f``).

    Args:
        vllm_config: Source of the HF model config and of the cache,
            quantization and LoRA configs.
        prefix: Dotted name of this module; layers are created under
            ``f"{prefix}.h"``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size

        # With LoRA enabled the embedding table grows by the extra LoRA
        # vocabulary, once per LoRA slot (at least one slot).
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = config.vocab_size + extra_vocab

        self.wte = VocabParallelEmbedding(
            self.vocab_size,
            self.embed_dim,
            org_num_embeddings=config.vocab_size,
        )
        self.wpe = nn.Embedding(config.max_position_embeddings,
                                self.embed_dim)

        def build_block(prefix: str) -> GPTBigCodeBlock:
            # Factory handed to make_layers; it supplies each layer's name.
            return GPTBigCodeBlock(config,
                                   cache_config,
                                   quant_config,
                                   prefix=prefix)

        # make_layers returns a start index, an end index and the layer
        # list; presumably the range is the slice owned by this pipeline
        # rank (confirm against make_layers).
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim,
                                 eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids`` in ``wte``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the layers in ``[start_layer, end_layer)``.

        On the first pipeline rank the input is the token embedding (or
        ``inputs_embeds`` when given) plus the position embedding; on
        other ranks it is read from ``intermediate_tensors``. Ranks other
        than the last return ``IntermediateTensors``; the last rank
        returns the hidden states after ``ln_f``.
        """
        if not get_pp_group().is_first_rank:
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            hidden_states = inputs_embeds + self.wpe(position_ids)

        for block in self.h[self.start_layer:self.end_layer]:
            hidden_states = block(hidden_states)

        if get_pp_group().is_last_rank:
            return self.ln_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs.

        Returns:
            The names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        fused_scale_keys = ("c_attn.input_scale", "c_attn.weight_scale")

        for name, loaded_weight in weights:
            # ".attn.bias" is the attention mask and is skipped.
            # NOTE: "c_attn.bias" should not be skipped.
            if ".attn.bias" in name or is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            if any(key in name for key in fused_scale_keys):
                # The same scale is loaded once for each of q, k and v.
                for shard_id in ("q", "k", "v"):
                    weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

272

273
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GPTBigCode transformer with a language-model head.

    Wraps ``GPTBigCodeModel`` as ``transformer`` and adds ``lm_head`` plus a
    ``LogitsProcessor``. When ``config.tie_word_embeddings`` is set,
    ``lm_head`` is the transformer's token embedding itself.
    """

    packed_modules_mapping = {"c_attn": ["c_attn"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
                                           prefix=prefix)

        if not self.config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                self.transformer.vocab_size,
                self.transformer.embed_dim,
                org_num_embeddings=self.config.vocab_size,
            )
        else:
            # Tied weights: reuse the input embedding as the output head.
            self.lm_head = self.transformer.wte

        # Base vocabulary plus the extra LoRA vocabulary, if LoRA is on.
        lora_extra = lora_config.lora_extra_vocab_size if lora_config else 0
        self.unpadded_vocab_size = config.vocab_size + lora_extra
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the token-embedding lookup to the transformer."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped transformer and return its output unchanged."""
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from hidden states via the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        With tied embeddings, weights whose names start with ``"lm_head."``
        are passed as skip prefixes because the head shares ``wte``.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated
            as the set of loaded parameter names).
        """
        tied = self.config.tie_word_embeddings
        skip_prefixes = ["lm_head."] if tied else None
        return AutoWeightsLoader(
            self,
            skip_prefixes=skip_prefixes,
        ).load_weights(weights)