# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
22
23
from collections.abc import Iterable
from typing import Optional, Union
24
25
26
27
28

import torch
from torch import nn
from transformers import GPTBigCodeConfig

29
from vllm.attention import Attention
30
from vllm.compilation.decorators import support_torch_compile
31
from vllm.config import CacheConfig, VllmConfig
32
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
33
from vllm.model_executor.layers.activation import get_act_fn
34
35
36
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
37
from vllm.model_executor.layers.logits_processor import LogitsProcessor
38
from vllm.model_executor.layers.quantization import QuantizationConfig
39
from vllm.model_executor.layers.vocab_parallel_embedding import (
40
    ParallelLMHead, VocabParallelEmbedding)
41
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
42
from vllm.model_executor.sampling_metadata import SamplingMetadata
43
from vllm.sequence import IntermediateTensors
44

45
from .interfaces import SupportsLoRA, SupportsPP
46
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
47
                    make_empty_intermediate_tensors_factory, make_layers)
48

49
50
51

class GPTBigCodeAttention(nn.Module):
    """Attention sub-layer of a GPTBigCode block.

    The input is projected by a fused QKV linear layer (``c_attn``), the
    result is split into query / key / value, passed through the attention
    layer, and finally through the output projection (``c_proj``).

    When ``config.multi_query`` is true a single key/value head is used;
    otherwise there is one key/value head per query head.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections and the attention layer.

        Args:
            config: Model configuration; ``hidden_size``,
                ``num_attention_heads`` and ``multi_query`` are read.
            cache_config: Forwarded to the ``Attention`` layer.
            quant_config: Forwarded to the linear layers and to the
                ``Attention`` layer.
            prefix: Qualified name of this module; each sub-layer is
                constructed with ``f"{prefix}.<attribute name>"``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Query heads are divided evenly across tensor-parallel ranks.
        assert total_num_heads % self.tensor_model_parallel_world_size == 0, (
            "num_attention_heads must be divisible by the tensor parallel "
            "world size")
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        # Width of the key (and of the value) slice in this rank's QKV output.
        self.kv_dim = self.head_dim * self.num_kv_heads

        # The linear layers receive their qualified name, the same way the
        # attention layer below does, instead of being constructed unnamed.
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply QKV projection, attention and output projection.

        Args:
            hidden_states: Input activations; split below along the last
                dimension after the QKV projection.

        Returns:
            The output of ``c_proj`` applied to the attention result.
        """
        qkv, _ = self.c_attn(hidden_states)
        # Query takes this rank's share of the hidden size; key and value
        # each take ``kv_dim`` columns.
        q, k, v = qkv.split(
            [
                self.hidden_size // self.tensor_model_parallel_world_size,
                self.kv_dim, self.kv_dim
            ],
            dim=-1,
        )
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer: ``c_fc`` -> activation -> ``c_proj``."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the two projections and the activation function.

        Args:
            intermediate_size: Output width of ``c_fc`` and input width of
                ``c_proj``.
            config: Model configuration; ``hidden_size`` and
                ``activation_function`` are read.
            quant_config: Forwarded to both linear layers.
        """
        super().__init__()
        embed_dim = config.hidden_size
        # hidden -> intermediate
        self.c_fc = ColumnParallelLinear(
            embed_dim,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        # intermediate -> hidden
        self.c_proj = RowParallelLinear(
            intermediate_size,
            embed_dim,
            bias=True,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``c_proj(act(c_fc(hidden_states)))``."""
        # Each linear layer returns a pair; only the first item is used.
        expanded, _ = self.c_fc(hidden_states)
        projected, _ = self.c_proj(self.act(expanded))
        return projected


class GPTBigCodeBlock(nn.Module):
    """One transformer layer: LayerNorm + attention, then LayerNorm + MLP.

    Each of the two sub-layers is wrapped in a residual connection.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the layer norms, the attention module and the MLP.

        Args:
            config: Model configuration; ``hidden_size``, ``n_inner`` and
                ``layer_norm_epsilon`` are read.
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to the attention module and the MLP.
            prefix: Qualified name of this block; the attention module is
                constructed with ``f"{prefix}.attn"``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        # The MLP width falls back to 4x the hidden size when the config
        # leaves ``n_inner`` unset.
        if config.n_inner is not None:
            inner_dim = config.n_inner
        else:
            inner_dim = 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(config,
                                        cache_config,
                                        quant_config,
                                        prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers with residual connections."""
        # Attention sub-layer: normalize, attend, add the input back.
        attn_input = self.ln_1(hidden_states)
        attn_output = self.attn(hidden_states=attn_input)
        after_attn = attn_output + hidden_states

        # MLP sub-layer: normalize, feed forward, add the input back.
        mlp_output = self.mlp(self.ln_2(after_attn))
        return after_attn + mlp_output


189
@support_torch_compile
class GPTBigCodeModel(nn.Module):
    """GPTBigCode backbone: embeddings, the stack of blocks, final norm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embeddings, the transformer blocks and ``ln_f``.

        Args:
            vllm_config: Supplies the HF model config, cache config,
                quantization config and LoRA config.
            prefix: Qualified name of this module; the blocks are created
                under ``f"{prefix}.h"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size

        # With LoRA enabled the embedding table is enlarged by
        # ``lora_extra_vocab_size`` rows per adapter (at least one adapter).
        if lora_config:
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1))
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab

        self.wte = VocabParallelEmbedding(self.vocab_size,
                                          self.embed_dim,
                                          org_num_embeddings=config.vocab_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        def _build_block(prefix: str) -> GPTBigCodeBlock:
            # Layer factory handed to ``make_layers``; the parameter name
            # ``prefix`` is kept from the original lambda.
            return GPTBigCodeBlock(config,
                                   cache_config,
                                   quant_config,
                                   prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            _build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids`` in ``wte``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's slice of the model.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            position_ids: Positions fed to the ``wpe`` embedding on the
                first pipeline rank.
            intermediate_tensors: Source of ``"hidden_states"`` on every
                rank other than the first.
            inputs_embeds: Optional precomputed token embeddings.

        Returns:
            ``IntermediateTensors`` holding ``"hidden_states"`` when this is
            not the last pipeline rank, otherwise the output of ``ln_f``.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            token_embeds = inputs_embeds
            if token_embeds is None:
                token_embeds = self.get_input_embeddings(input_ids)
            hidden_states = token_embeds + self.wpe(position_ids)
        else:
            hidden_states = intermediate_tensors["hidden_states"]

        # Only the layers in [start_layer, end_layer) are executed here.
        for block in self.h[self.start_layer:self.end_layer]:
            hidden_states = block(hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return self.ln_f(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # ".attn.bias" is the attention mask and is skipped.
            # NOTE: "c_attn.bias" should not be skipped.
            # Names rejected by is_pp_missing_parameter are skipped as well.
            if ".attn.bias" in name or is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            is_c_attn_scale = ("c_attn.input_scale" in name
                               or "c_attn.weight_scale" in name)
            if is_c_attn_scale:
                # The same scale tensor is loaded once per q/k/v shard.
                for shard_id in ("q", "k", "v"):
                    weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

271

272
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GPTBigCode backbone plus LM head and logits processor."""

    packed_modules_mapping = {"c_attn": ["c_attn"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the LM head and the logits processor.

        Args:
            vllm_config: Supplies the HF model config, quantization config
                and LoRA config; also passed on to ``GPTBigCodeModel``.
            prefix: Passed unchanged to ``GPTBigCodeModel``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config

        self.config = hf_config
        self.lora_config = lora_config
        self.quant_config = vllm_config.quant_config

        self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
                                           prefix=prefix)

        if self.config.tie_word_embeddings:
            # Tied embeddings: the token embedding doubles as the LM head.
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(
                self.transformer.vocab_size,
                self.transformer.embed_dim,
                org_num_embeddings=self.config.vocab_size)

        # Base vocabulary plus the extra LoRA vocabulary, when LoRA is on.
        unpadded_vocab_size = hf_config.vocab_size
        if lora_config:
            unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.unpadded_vocab_size = unpadded_vocab_size

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Call the backbone and return its result unchanged."""
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return what the logits processor yields for ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        With tied word embeddings, checkpoint entries under ``lm_head.`` are
        skipped.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        skip_prefixes = (["lm_head."]
                         if self.config.tie_word_embeddings else None)
        return AutoWeightsLoader(
            self,
            skip_prefixes=skip_prefixes,
        ).load_weights(weights)