# gpt_bigcode.py
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
21
from typing import Iterable, List, Optional, Tuple
22
23
24
25
26

import torch
from torch import nn
from transformers import GPTBigCodeConfig

27
from vllm.attention import Attention, AttentionMetadata
28
from vllm.config import CacheConfig, LoRAConfig
29
from vllm.distributed import get_tensor_model_parallel_world_size
30
from vllm.model_executor.layers.activation import get_act_fn
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
36
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
37
from vllm.model_executor.layers.sampler import Sampler
38
39
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
41
from vllm.model_executor.sampling_metadata import SamplingMetadata
42
from vllm.sequence import SamplerOutput
43
44
45
46


class GPTBigCodeAttention(nn.Module):
    """Self-attention sub-layer of a GPTBigCode block.

    Query heads are divided evenly across tensor-parallel ranks. When
    ``config.multi_query`` is set, a single key/value head is used for all
    query heads (multi-query attention); otherwise each query head has its
    own key/value head.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the fused QKV projection, attention op and output projection.

        Args:
            config: HuggingFace GPTBigCode configuration.
            cache_config: Forwarded to the ``Attention`` op.
            quant_config: Forwarded to the linear layers and ``Attention``.

        Raises:
            ValueError: If ``config.num_attention_heads`` is not divisible by
                the tensor-parallel world size.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Explicit check instead of `assert`, which is stripped under -O.
        if total_num_heads % self.tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_attention_heads ({total_num_heads}) must be divisible "
                f"by the tensor parallel world size "
                f"({self.tensor_model_parallel_world_size}).")
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            # One KV head in total, and this rank also uses exactly one.
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        # Per-rank widths of the q / k / v slices of the fused c_attn output,
        # derived from the same head counts and head_dim that c_attn is built
        # with. q_size equals hidden_size // tp_size whenever hidden_size is a
        # multiple of num_attention_heads.
        self.q_size = self.num_heads * self.head_dim
        self.kv_dim = self.head_dim * self.num_kv_heads
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and project the result back.

        Args:
            hidden_states: Input activations fed to ``c_attn`` (presumably
                ``(num_tokens, hidden_size)`` -- confirm against the runner).
            kv_cache: This layer's KV cache, passed through to ``self.attn``.
            attn_metadata: Attention metadata for the current batch.

        Returns:
            The ``c_proj`` projection of the attention output.
        """
        qkv, _ = self.c_attn(hidden_states)
        # Split the fused projection along the last dim into q, k, v.
        q, k, v = qkv.split([self.q_size, self.kv_dim, self.kv_dim], dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer of a GPTBigCode block.

    Computes ``c_proj(act(c_fc(x)))``: a column-parallel expansion from
    ``hidden_size`` to ``intermediate_size``, the activation named by
    ``config.activation_function``, then a row-parallel projection back to
    ``hidden_size``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the two linear layers and the activation.

        Args:
            intermediate_size: Width of the hidden (expanded) layer.
            config: HuggingFace GPTBigCode configuration.
            quant_config: Forwarded to the linear layers and ``get_act_fn``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``c_fc`` -> activation -> ``c_proj`` to ``hidden_states``."""
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPTBigCodeBlock(nn.Module):
    """One pre-LayerNorm transformer layer of GPTBigCode.

    ``x = x + attn(ln_1(x))`` followed by ``x = x + mlp(ln_2(x))``.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the layer norms, attention and MLP sub-layers.

        Args:
            config: HuggingFace GPTBigCode configuration. ``config.n_inner``
                sets the MLP width; when it is ``None`` the width defaults to
                ``4 * hidden_size``.
            cache_config: Forwarded to the attention sub-layer.
            quant_config: Forwarded to the attention and MLP sub-layers.
        """
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers, each with a residual add.

        Args:
            hidden_states: Layer input.
            kv_cache: This layer's KV cache, passed to the attention layer.
            attn_metadata: Attention metadata for the current batch.

        Returns:
            The layer output, same shape as ``hidden_states``.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class GPTBigCodeModel(nn.Module):
    """GPTBigCode decoder stack without the language-model head.

    Token embeddings (``wte``) are summed with learned position embeddings
    (``wpe``), passed through ``config.num_hidden_layers`` blocks, and
    normalized by a final LayerNorm (``ln_f``).
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build embeddings, transformer blocks and the final norm.

        Args:
            config: HuggingFace GPTBigCode configuration.
            cache_config: Forwarded to every block.
            quant_config: Forwarded to every block.
            lora_config: If given, the token embedding table is enlarged by
                ``lora_extra_vocab_size * (max_loras or 1)`` rows.

        Raises:
            NotImplementedError: If ``config.add_cross_attention`` is set.
        """
        super().__init__()
        self.config = config
        # Explicit check instead of `assert`, which is stripped under -O.
        if config.add_cross_attention:
            raise NotImplementedError(
                "GPTBigCode with add_cross_attention=True is not supported.")

        self.embed_dim = config.hidden_size
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.wte = VocabParallelEmbedding(self.vocab_size,
                                          self.embed_dim,
                                          org_num_embeddings=config.vocab_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPTBigCodeBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed the inputs and run them through every block.

        Args:
            input_ids: Token ids looked up in ``wte``.
            position_ids: Position indices looked up in ``wpe``.
            kv_caches: One KV cache per block, indexed by layer number.
            attn_metadata: Attention metadata for the current batch.

        Returns:
            Final hidden states after ``ln_f``.
        """
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)
        for i, layer in enumerate(self.h):
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
        return self.ln_f(hidden_states)


class GPTBigCodeForCausalLM(nn.Module):
    """GPTBigCode decoder with a language-model head for vLLM.

    The LM head is tied to the token embedding: logits are computed with
    ``self.transformer.wte.weight``, so no separate ``lm_head`` weight is
    loaded from checkpoints.
    """

    packed_modules_mapping = {"c_attn": ["c_attn"]}

    supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]

    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    embedding_padding_modules = []

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build the decoder, tied LM-head weight, logits processor and sampler.

        Args:
            config: HuggingFace GPTBigCode configuration.
            cache_config: Forwarded to the decoder.
            quant_config: Forwarded to the decoder and stored on ``self``.
            lora_config: Forwarded to the decoder; when given, the unpadded
                vocab size grows by ``lora_extra_vocab_size``.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
                                           lora_config)
        # Weight tying: the output projection reuses the input embedding.
        self.lm_head_weight = self.transformer.wte.weight
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states (not logits).

        Logits are produced separately by ``compute_logits``.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project ``hidden_states`` to vocabulary logits via the tied weight."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Select next tokens from ``logits`` using ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Args:
            weights: ``(name, tensor)`` pairs; each name that is not skipped
                must match a name from
                ``self.named_parameters(remove_duplicate=False)``.

        Raises:
            KeyError: If a non-skipped checkpoint name has no matching
                parameter.
        """
        # remove_duplicate=False keeps every alias of a shared parameter.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # Tied to transformer.wte.weight; nothing separate to load.
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
                # The single checkpoint scale is loaded into each of the
                # q, k and v shards of the fused c_attn parameter.
                for shard_id in ("q", "k", "v"):
                    weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader(param, loaded_weight)