gpt_bigcode.py 11.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
22
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
23

24
from collections.abc import Iterable
25
from itertools import islice
26
27
28
29
30

import torch
from torch import nn
from transformers import GPTBigCodeConfig

31
from vllm.attention.layer import Attention
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
35
from vllm.model_executor.layers.activation import get_act_fn
36
37
38
39
40
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.quantization import QuantizationConfig
43
from vllm.model_executor.layers.vocab_parallel_embedding import (
44
45
46
    ParallelLMHead,
    VocabParallelEmbedding,
)
47
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
48
from vllm.sequence import IntermediateTensors
49

50
from .interfaces import SupportsLoRA, SupportsPP
51
52
53
54
55
56
57
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
58

59
60

class GPTBigCodeAttention(nn.Module):
    """Self-attention sub-layer used inside a GPTBigCode block.

    Owns three sub-modules, registered in this order:

    * ``c_attn`` -- fused query/key/value projection (``QKVParallelLinear``).
    * ``c_proj`` -- output projection (``RowParallelLinear``).
    * ``attn``   -- the ``Attention`` layer itself.

    When ``config.multi_query`` is truthy a single key/value head is used;
    otherwise the number of key/value heads equals the number of query heads.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the projections and the attention layer.

        Args:
            config: Model configuration; ``hidden_size``,
                ``num_attention_heads`` and ``multi_query`` are read.
            cache_config: Forwarded unchanged to ``Attention``.
            quant_config: Forwarded unchanged to every sub-layer.
            prefix: Dotted name prefix; ``.c_attn``, ``.c_proj`` and
                ``.attn`` are appended for the respective sub-layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads

        tp_size = get_tensor_model_parallel_world_size()
        self.tensor_model_parallel_world_size = tp_size
        # Query heads must divide evenly across tensor-parallel ranks.
        assert total_num_heads % tp_size == 0
        self.num_heads = total_num_heads // tp_size
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        # Multi-query attention keeps exactly one key/value head; otherwise
        # the key/value head counts mirror the query head counts.
        total_num_kv_heads = 1 if self.multi_query else total_num_heads
        self.num_kv_heads = 1 if self.multi_query else self.num_heads
        self.kv_dim = self.head_dim * self.num_kv_heads

        qkv_prefix = f"{prefix}.c_attn"
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=qkv_prefix,
        )

        out_prefix = f"{prefix}.c_proj"
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=out_prefix,
        )

        attn_prefix = f"{prefix}.attn"
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=attn_prefix,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, run attention, and apply the output projection.

        The fused projection output is split along the last dimension into
        a query slice of width ``hidden_size // tensor_model_parallel_world_size``
        followed by two slices of width ``kv_dim`` (key, then value).
        """
        fused_qkv, _ = self.c_attn(hidden_states)
        query_width = self.hidden_size // self.tensor_model_parallel_world_size
        split_sizes = [query_width, self.kv_dim, self.kv_dim]
        q, k, v = fused_qkv.split(split_sizes, dim=-1)
        context = self.attn(q, k, v)
        projected, _ = self.c_proj(context)
        return projected


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer: ``c_fc`` -> activation -> ``c_proj``."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the two linear layers and look up the activation.

        Args:
            intermediate_size: Output width of ``c_fc`` and input width of
                ``c_proj``.
            config: Model configuration; ``hidden_size`` and
                ``activation_function`` are read.
            quant_config: Forwarded unchanged to both linear layers.
            prefix: Dotted name prefix; ``.c_fc`` and ``.c_proj`` are
                appended for the respective layers.
        """
        super().__init__()
        model_width = config.hidden_size

        # Up-projection: hidden_size -> intermediate_size.
        up_prefix = f"{prefix}.c_fc"
        self.c_fc = ColumnParallelLinear(
            model_width,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=up_prefix,
        )

        # Down-projection: intermediate_size -> hidden_size.
        down_prefix = f"{prefix}.c_proj"
        self.c_proj = RowParallelLinear(
            intermediate_size,
            model_width,
            bias=True,
            quant_config=quant_config,
            prefix=down_prefix,
        )

        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``c_fc``, the activation, then ``c_proj`` and return the result.

        Both linear layers return a pair; only the first element is used.
        """
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.c_proj(activated)
        return contracted


class GPTBigCodeBlock(nn.Module):
    """One transformer layer: pre-norm attention and pre-norm MLP.

    Each half applies a ``LayerNorm`` first and then adds the sub-layer
    output back onto the un-normalized input (residual connection).
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build ``ln_1``, ``attn``, ``ln_2`` and ``mlp`` (in that order).

        Args:
            config: Model configuration; ``hidden_size``, ``n_inner`` and
                ``layer_norm_epsilon`` are read here.
            cache_config: Forwarded to ``GPTBigCodeAttention``.
            quant_config: Forwarded to the attention and MLP sub-layers.
            prefix: Dotted name prefix; ``.attn`` and ``.mlp`` are appended.
        """
        super().__init__()
        hidden_size = config.hidden_size
        # The MLP width falls back to 4x the hidden size when ``n_inner``
        # is not provided by the configuration.
        if config.n_inner is not None:
            inner_dim = config.n_inner
        else:
            inner_dim = 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.attn",
        )
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(
            inner_dim,
            config,
            quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and the MLP, each wrapped in a residual connection."""
        # Attention half: normalize, attend, add the input back.
        attn_input = hidden_states
        normed = self.ln_1(attn_input)
        attn_output = self.attn(hidden_states=normed)
        after_attn = attn_output + attn_input

        # MLP half: normalize, feed forward, add the input back.
        mlp_output = self.mlp(self.ln_2(after_attn))
        return after_attn + mlp_output


202
@support_torch_compile
class GPTBigCodeModel(nn.Module):
    """GPTBigCode transformer stack without the language-model head.

    Holds the token embedding (``wte``), the position embedding (``wpe``),
    the layers produced by ``make_layers`` (``h``) and the final
    ``LayerNorm`` (``ln_f``).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build embeddings, transformer layers and the final norm.

        Args:
            vllm_config: Source of the HF model config
                (``model_config.hf_config``), ``cache_config`` and
                ``quant_config``.
            prefix: Dotted name prefix; the layers are created under
                ``f"{prefix}.h"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        # Cross-attention is not handled by this implementation.
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            self.vocab_size,
            self.embed_dim,
            org_num_embeddings=config.vocab_size,
        )
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        def build_block(prefix: str) -> GPTBigCodeBlock:
            # Factory handed to ``make_layers``; closes over the configs.
            return GPTBigCodeBlock(
                config, cache_config, quant_config, prefix=prefix
            )

        layers_prefix = f"{prefix}.h"
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            build_block,
            prefix=layers_prefix,
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.n_embd
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids`` (``wte`` lookup only)."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the layers in ``[start_layer, end_layer)``.

        On the first pipeline rank the input is the token embedding
        (computed from ``input_ids`` unless ``inputs_embeds`` is given) plus
        the position embedding. On other ranks it is
        ``intermediate_tensors["hidden_states"]``.

        Returns:
            ``IntermediateTensors`` carrying ``hidden_states`` when this is
            not the last pipeline rank; otherwise the hidden states after
            the final ``LayerNorm``.
        """
        if not get_pp_group().is_first_rank:
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            if inputs_embeds is None:
                inputs_embeds = self.embed_input_ids(input_ids)
            hidden_states = inputs_embeds + self.wpe(position_ids)

        for block in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = block(hidden_states)

        if get_pp_group().is_last_rank:
            return self.ln_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of names that were loaded.

        Raises:
            KeyError: If a name that is not skipped has no matching entry
                in ``named_parameters``.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # Skip attention mask.
            # NOTE: "c_attn.bias" should not be skipped.
            is_attn_mask = ".attn.bias" in name
            if is_attn_mask or is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            if "c_attn.input_scale" not in name:
                weight_loader(param, loaded_weight)
            else:
                # The same scale tensor is loaded once per fused shard.
                for shard_id in ("q", "k", "v"):
                    weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)

        return loaded_params

280

281
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GPTBigCode transformer plus language-model head.

    Wraps ``GPTBigCodeModel`` as ``transformer``. The head (``lm_head``) is
    the transformer's token embedding when ``config.tie_word_embeddings`` is
    truthy, and a separate ``ParallelLMHead`` otherwise.
    """

    packed_modules_mapping = {"c_attn": ["c_attn"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the transformer, the LM head and the logits processor.

        Args:
            vllm_config: Source of the HF model config and ``quant_config``;
                also passed through to ``GPTBigCodeModel``.
            prefix: Name prefix combined via ``maybe_prefix`` with
                ``"transformer"`` and ``"lm_head"``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        self.config = hf_config
        self.quant_config = vllm_config.quant_config

        transformer_prefix = maybe_prefix(prefix, "transformer")
        self.transformer = GPTBigCodeModel(
            vllm_config=vllm_config, prefix=transformer_prefix
        )

        if not self.config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                self.transformer.vocab_size,
                self.transformer.embed_dim,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            # Tied embeddings: reuse the input embedding as the output head.
            self.lm_head = self.transformer.wte

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Re-export the transformer's factory on this wrapper.
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding lookup to the wrapped transformer."""
        return self.transformer.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the wrapped transformer and return its output unchanged."""
        return self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Return ``logits_processor(lm_head, hidden_states)``."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        When the word embeddings are tied, names starting with ``lm_head.``
        are passed as ``skip_prefixes``; otherwise ``skip_prefixes`` is
        ``None``.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated
            as the set of loaded names).
        """
        tied = self.config.tie_word_embeddings
        skip_prefixes = ["lm_head."] if tied else None
        return AutoWeightsLoader(
            self,
            skip_prefixes=skip_prefixes,
        ).load_weights(weights)