# SPDX-License-Identifier: Apache-2.0

3
4
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
22
from typing import Iterable, List, Optional, Set, Tuple, Union
23
24
25
26
27

import torch
from torch import nn
from transformers import GPTBigCodeConfig

28
from vllm.attention import Attention, AttentionMetadata
29
from vllm.compilation.decorators import support_torch_compile
30
from vllm.config import CacheConfig, VllmConfig
31
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
32
from vllm.model_executor.layers.activation import get_act_fn
33
34
35
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
36
from vllm.model_executor.layers.logits_processor import LogitsProcessor
37
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
39
from vllm.model_executor.layers.vocab_parallel_embedding import (
40
    ParallelLMHead, VocabParallelEmbedding)
41
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
42
from vllm.model_executor.sampling_metadata import SamplingMetadata
43
from vllm.sequence import IntermediateTensors
44

45
46
47
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
48

49
50
51

class GPTBigCodeAttention(nn.Module):
    """Self-attention sub-layer of a GPTBigCode decoder block.

    Supports standard multi-head attention and, when ``config.multi_query``
    is set, multi-query attention in which a single key/value head serves
    all query heads. Query heads are divided evenly across tensor-parallel
    ranks.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Query heads must split evenly across tensor-parallel ranks.
        assert total_num_heads % self.tensor_model_parallel_world_size == 0
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            # Multi-query attention: a single KV head in total, and this
            # rank's attention also uses exactly one KV head.
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            # One KV head per query head.
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        # Width of this rank's K slice (and, equally, V slice) of the fused
        # QKV projection output.
        self.kv_dim = self.head_dim * self.num_kv_heads
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply self-attention followed by the ``c_proj`` projection.

        Args:
            hidden_states: Input activations. The fused QKV projection of
                these is split along its last dimension.
            kv_cache: This layer's KV cache, forwarded to ``self.attn``.
            attn_metadata: Attention metadata for the current batch,
                forwarded to ``self.attn``.

        Returns:
            The attention output after the ``c_proj`` output projection.
        """
        qkv, _ = self.c_attn(hidden_states)
        # Fused layout on this rank is [q | k | v]. q spans this rank's query
        # heads (hidden_size / tp_size columns); k and v each span kv_dim.
        q, k, v = qkv.split(
            [
                self.hidden_size // self.tensor_model_parallel_world_size,
                self.kv_dim, self.kv_dim
            ],
            dim=-1,
        )
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer of a GPTBigCode block.

    Computes ``c_proj(act(c_fc(x)))``:
    - ``c_fc`` (column-parallel) maps ``hidden_size`` to ``intermediate_size``.
    - ``act`` is the activation named by ``config.activation_function``.
    - ``c_proj`` (row-parallel) maps back to ``hidden_size``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Expand, apply the activation, then project back to hidden size."""
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPTBigCodeBlock(nn.Module):
    """One GPTBigCode transformer layer with pre-LayerNorm residuals.

    Computes ``x = x + attn(ln_1(x))`` and then ``x = x + mlp(ln_2(x))``.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # The MLP width falls back to 4 * hidden_size when n_inner is None.
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(config,
                                        cache_config,
                                        quant_config,
                                        prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers, each with a residual add.

        Args:
            hidden_states: Input activations for this layer.
            kv_cache: This layer's KV cache, forwarded to the attention.
            attn_metadata: Attention metadata for the current batch.

        Returns:
            The layer output, same width as ``hidden_states``.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


@support_torch_compile
class GPTBigCodeModel(nn.Module):
    """GPTBigCode decoder stack (no LM head).

    Components:
    - ``wte``: token embeddings (vocabulary extended for LoRA when enabled).
    - ``wpe``: learned absolute position embeddings.
    - ``h``: the transformer blocks assigned to this pipeline-parallel rank.
    - ``ln_f``: the final LayerNorm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # Cross-attention (encoder-decoder use) is not implemented here.
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size
        # Extra embedding rows reserved for LoRA-added tokens.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.wte = VocabParallelEmbedding(self.vocab_size,
                                          self.embed_dim,
                                          org_num_embeddings=config.vocab_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        # Only layers in [start_layer, end_layer) run on this rank.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTBigCodeBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids`` via ``wte``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder.

        Args:
            input_ids: Token ids. Only used on the first pipeline rank, and
                only when ``inputs_embeds`` is not given.
            position_ids: Positions looked up in ``wpe`` on the first rank.
            kv_caches: KV caches for this rank's layers, indexed from
                ``start_layer``.
            attn_metadata: Attention metadata for the current batch.
            intermediate_tensors: Output of the previous pipeline rank. Read
                only on non-first ranks.
            inputs_embeds: Optional precomputed token embeddings.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` on non-last
            ranks; otherwise the ``ln_f``-normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            hidden_states = inputs_embeds + self.wpe(position_ids)
        else:
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GPTBigCode decoder plus LM head, logits processor and sampler."""

    packed_modules_mapping = {"c_attn": ["c_attn"]}

    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]

    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
                                           prefix=prefix)
        if self.config.tie_word_embeddings:
            # Output projection shares the input embedding module.
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(
                self.transformer.vocab_size,
                self.transformer.embed_dim,
                org_num_embeddings=self.config.vocab_size)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the underlying transformer."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder; returns hidden states (not logits).

        On non-last pipeline ranks the result is an ``IntermediateTensors``.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states through ``lm_head`` via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` using the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Copy checkpoint tensors into this model's parameters.

        Args:
            weights: ``(name, tensor)`` pairs read from the checkpoint.

        Returns:
            The names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                # Tied: lm_head *is* transformer.wte, which is loaded from
                # its own checkpoint entry. An untied ParallelLMHead owns a
                # separate weight and must be loaded here, otherwise it
                # never receives the checkpoint values.
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
                # One checkpoint scale is loaded into each of the q, k and v
                # shards of the fused c_attn parameter.
                weight_loader(param, loaded_weight, 'q')
                weight_loader(param, loaded_weight, 'k')
                weight_loader(param, loaded_weight, 'v')
            else:
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params