gpt_bigcode.py 12.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
22
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
23

24
from collections.abc import Iterable
25
from itertools import islice
26
from typing import Optional, Union
27
28
29
30
31

import torch
from torch import nn
from transformers import GPTBigCodeConfig

32
from vllm.attention import Attention
33
from vllm.compilation.decorators import support_torch_compile
34
from vllm.config import CacheConfig, VllmConfig
35
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
36
from vllm.model_executor.layers.activation import get_act_fn
37
38
39
40
41
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
42
from vllm.model_executor.layers.logits_processor import LogitsProcessor
43
from vllm.model_executor.layers.quantization import QuantizationConfig
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
45
46
47
    ParallelLMHead,
    VocabParallelEmbedding,
)
48
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
49
from vllm.sequence import IntermediateTensors
50

51
from .interfaces import SupportsLoRA, SupportsPP
52
53
54
55
56
57
58
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
59

60
61

class GPTBigCodeAttention(nn.Module):
    """Self-attention for one GPTBigCode layer, sharded over tensor-parallel ranks.

    Query heads are divided evenly across ranks. When ``config.multi_query`` is
    set, the model has a single key/value head shared by all query heads;
    otherwise there is one key/value head per query head.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        # Query heads must split evenly across tensor-parallel ranks.
        assert total_num_heads % self.tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // self.tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            # Multi-query attention: one KV head in total, and each rank
            # attends with exactly one KV head.
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        # Per-rank widths of the q and k/v slices of the fused c_attn output.
        # Both are derived from head counts, so they match how c_attn is built below.
        self.q_size = self.num_heads * self.head_dim
        self.kv_dim = self.head_dim * self.num_kv_heads
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states`` and project the result.

        Args:
            hidden_states: Layer input. ``c_attn`` is applied to it and the
                result is split along the last dimension (presumably shaped
                (num_tokens, hidden_size) — confirm against the caller).

        Returns:
            The output of ``c_proj`` applied to the attention result.
        """
        qkv, _ = self.c_attn(hidden_states)
        # Split the fused projection into this rank's query, key and value slices.
        q, k, v = qkv.split([self.q_size, self.kv_dim, self.kv_dim], dim=-1)
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer of a GPTBigCode block: c_fc -> activation -> c_proj.

    ``c_fc`` is a ``ColumnParallelLinear`` and ``c_proj`` a
    ``RowParallelLinear``, both with bias. The activation function is taken
    from ``config.activation_function``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # Expand hidden_size -> intermediate_size.
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        # Project intermediate_size -> hidden_size.
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply c_fc, the activation, then c_proj, and return the result."""
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPTBigCodeBlock(nn.Module):
    """One pre-LayerNorm transformer layer.

    Computes ``x + attn(ln_1(x))`` and then ``h + mlp(ln_2(h))``.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # config.n_inner, when set, overrides the default 4x MLP expansion.
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.attn"
        )
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers, each with a residual add."""
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


203
@support_torch_compile
class GPTBigCodeModel(nn.Module):
    """GPTBigCode decoder stack without the language-modeling head.

    Holds the token embedding (``wte``), learned absolute position embedding
    (``wpe``), the slice of layers ``h`` owned by this pipeline stage, and the
    final LayerNorm ``ln_f``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # Cross-attention (encoder-decoder use) is not implemented here.
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size
        # Extra embedding rows reserved for LoRA-added vocabulary, sized for
        # every LoRA slot (max_loras, or 1 when unset).
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.wte = VocabParallelEmbedding(
            self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
        )
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        # make_layers returns the [start_layer, end_layer) range this pipeline
        # stage executes along with the layer container.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTBigCodeBlock(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        # Hidden states passed between pipeline stages have width embed_dim
        # (config.hidden_size, an alias of n_embd in GPTBigCodeConfig).
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], self.embed_dim
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids`` via ``wte``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's slice of the decoder.

        Args:
            input_ids: Token ids; embedded via ``wte`` on the first rank unless
                ``inputs_embeds`` is given.
            position_ids: Positions fed to ``wpe`` on the first rank.
            intermediate_tensors: Output of the previous pipeline stage;
                required on every rank except the first.
            inputs_embeds: Optional precomputed token embeddings.

        Returns:
            ``IntermediateTensors`` carrying ``"hidden_states"`` on non-last
            ranks; on the last rank, the hidden states after ``ln_f``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            # Add the learned absolute position embedding to the token embedding.
            hidden_states = inputs_embeds + self.wpe(position_ids)
        else:
            # Later stages resume from the previous stage's hidden states.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = layer(hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs whose names match this module's
                parameter names.

        Returns:
            The names of the parameters that were loaded.

        Raises:
            KeyError: if a non-skipped name has no matching parameter.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            # Presumably skips parameters owned by another pipeline stage.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            if "c_attn.input_scale" in name:
                # The checkpoint stores one scale for the fused c_attn; load it
                # into each of the q, k and v shards.
                for shard_id in ("q", "k", "v"):
                    weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

286

287
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GPTBigCode decoder (``transformer``) plus a language-modeling head."""

    # c_attn is a single fused QKV layer whose checkpoint tensor is also named
    # c_attn, so it maps only to itself.
    packed_modules_mapping = {"c_attn": ["c_attn"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.transformer = GPTBigCodeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        if self.config.tie_word_embeddings:
            # Reuse the input token embedding as the output projection.
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(
                self.transformer.vocab_size,
                self.transformer.embed_dim,
                org_num_embeddings=self.config.vocab_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        # Base vocabulary plus one lora_extra_vocab_size block when LoRA is
        # enabled; passed to LogitsProcessor together with config.vocab_size.
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size
        )
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings for ``input_ids`` from the decoder."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to ``self.transformer`` and return its output unchanged."""
        hidden_states = self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project final hidden states to logits through ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``.

        When word embeddings are tied, any ``lm_head.`` tensors are skipped
        because ``lm_head`` is the same module as ``transformer.wte``.

        Returns:
            The set of loaded parameter names reported by the loader.
        """
        skip_prefixes = None
        if self.config.tie_word_embeddings:
            skip_prefixes = ["lm_head."]
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=skip_prefixes,
        )
        return loader.load_weights(weights)