# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTBigCodeConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput

from .interfaces import SupportsLoRA


class GPTBigCodeAttention(nn.Module):
    """Self-attention sub-layer of a GPTBigCode decoder block.

    Query heads are sharded across tensor-parallel ranks. When
    ``config.multi_query`` is set, a single key/value head is used
    (multi-query attention); otherwise there is one key/value head per
    query head. Q, K and V are produced by one fused ``c_attn`` projection
    and the attention output is projected back by ``c_proj``.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        n_heads_total = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        self.tensor_model_parallel_world_size = tp_size

        # Every rank must own the same whole number of query heads.
        assert n_heads_total % tp_size == 0
        self.num_heads = n_heads_total // tp_size
        self.head_dim = self.hidden_size // n_heads_total
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        # Multi-query: one K/V head overall (and per rank); otherwise K/V
        # heads mirror the query heads.
        n_kv_heads_total = 1 if self.multi_query else n_heads_total
        self.num_kv_heads = 1 if self.multi_query else self.num_heads

        # Width of the per-rank K (and V) slice of the fused QKV output.
        self.kv_dim = self.head_dim * self.num_kv_heads
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            n_heads_total,
            n_kv_heads_total,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output.

        Args:
            hidden_states: Input activations (already layer-normed by caller).
            kv_cache: This layer's KV cache tensor, passed to ``self.attn``.
            attn_metadata: Backend attention metadata, passed to ``self.attn``.
        """
        fused_qkv, _ = self.c_attn(hidden_states)
        # The query slice holds this rank's share of the hidden size; K and V
        # each occupy ``kv_dim`` columns after it, split on the last dim.
        local_q_width = (self.hidden_size //
                         self.tensor_model_parallel_world_size)
        query, key, value = fused_qkv.split(
            [local_q_width, self.kv_dim, self.kv_dim],
            dim=-1,
        )
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.c_proj(context)
        return projected


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer of a GPTBigCode decoder block.

    ``c_fc`` (column-parallel) expands the hidden size to
    ``intermediate_size``, the activation named by
    ``config.activation_function`` is applied, and ``c_proj``
    (row-parallel) projects back to the hidden size.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        d_model = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            d_model,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            d_model,
            bias=True,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply expand -> activation -> project to ``hidden_states``."""
        # The parallel linear layers return (output, bias); only the output
        # is used here.
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.c_proj(activated)
        return contracted


class GPTBigCodeBlock(nn.Module):
    """A single GPTBigCode decoder layer.

    Layout: ``x + attn(ln_1(x))`` followed by ``h + mlp(ln_2(h))``
    (LayerNorm applied before each sub-layer, residual added after).
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.hidden_size
        # ``n_inner`` overrides the MLP width; the fallback is 4x hidden size.
        if config.n_inner is None:
            mlp_width = 4 * width
        else:
            mlp_width = config.n_inner
        eps = config.layer_norm_epsilon

        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        self.mlp = GPTBigMLP(mlp_width, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and MLP sub-layers, each with a residual connection."""
        # Attention sub-layer; ``hidden_states`` still holds the block input
        # and serves as the residual.
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = attn_out + hidden_states

        # MLP sub-layer with its own residual.
        mlp_out = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_out


class GPTBigCodeModel(nn.Module):
    """GPTBigCode decoder stack without the language-modeling head.

    Sums token (``wte``) and learned absolute position (``wpe``)
    embeddings, runs them through ``config.num_hidden_layers`` decoder
    blocks, and applies a final LayerNorm (``ln_f``).
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config = config
        # Cross-attention (encoder-decoder use) is not implemented here.
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size
        # With LoRA enabled, reserve extra embedding rows: one
        # ``lora_extra_vocab_size`` slice per LoRA slot (``max_loras``,
        # treated as 1 when unset/0).
        extra_lora_rows = 0
        if lora_config:
            extra_lora_rows = (lora_config.lora_extra_vocab_size *
                               (lora_config.max_loras or 1))
        self.vocab_size = config.vocab_size + extra_lora_rows
        self.wte = VocabParallelEmbedding(self.vocab_size,
                                          self.embed_dim,
                                          org_num_embeddings=config.vocab_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPTBigCodeBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and return final-LayerNormed hidden states.

        Args:
            input_ids: Token ids looked up in ``wte``.
            position_ids: Position indices looked up in ``wpe``.
            kv_caches: One cache per layer; layer ``k`` uses ``kv_caches[k]``.
            attn_metadata: Shared attention metadata for every layer.
        """
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)

        for layer_idx, layer in enumerate(self.h):
            hidden_states = layer(hidden_states, kv_caches[layer_idx],
                                  attn_metadata)

        return self.ln_f(hidden_states)


235
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
    """GPTBigCode decoder plus language-modeling head.

    Wraps :class:`GPTBigCodeModel` (as ``self.transformer``) and adds logits
    computation, token sampling and checkpoint loading. The output head is
    either the input embedding (``config.tie_word_embeddings``) or a separate
    ``ParallelLMHead``.
    """

    packed_modules_mapping = {"c_attn": ["c_attn"]}

    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]

    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    embedding_padding_modules = []

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
                                           lora_config)
        if self.config.tie_word_embeddings:
            # The output projection reuses the input embedding module.
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(
                self.transformer.vocab_size,
                self.transformer.embed_dim,
                org_num_embeddings=self.config.vocab_size)

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            # NOTE(review): GPTBigCodeModel sizes ``wte`` with
            # ``lora_extra_vocab_size * (max_loras or 1)`` extra rows, while
            # only one ``lora_extra_vocab_size`` slice is added here.
            # Confirm this asymmetry is intended.
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states (not logits).

        ``intermediate_tensors`` is accepted for interface compatibility but
        is not used by this model.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` through ``lm_head`` via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Args:
            weights: ``(name, tensor)`` pairs. Names must match this module's
                parameter names; an unknown name raises ``KeyError``.
        """
        # remove_duplicate=False keeps the ``lm_head.*`` alias names that
        # exist when lm_head is the very same module as ``transformer.wte``.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        tie_embeddings = self.config.tie_word_embeddings
        for name, loaded_weight in weights:
            if "lm_head.weight" in name and tie_embeddings:
                # A tied head shares its parameter with transformer.wte,
                # which is loaded under its own name. An untied
                # ParallelLMHead owns a separate weight and must be loaded,
                # otherwise it would keep its random initialization.
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)

            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
                # The fused QKV layer keeps one scale per shard; the single
                # checkpoint scale is written to each of q, k and v.
                weight_loader(param, loaded_weight, 'q')
                weight_loader(param, loaded_weight, 'k')
                weight_loader(param, loaded_weight, 'v')
            else:
                weight_loader(param, loaded_weight)