gpt_bigcode.py 12.8 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
21
from typing import Iterable, List, Optional, Tuple, Union
22
23
24
25
26

import torch
from torch import nn
from transformers import GPTBigCodeConfig

27
from vllm.attention import Attention, AttentionMetadata
28
from vllm.compilation.decorators import support_torch_compile
29
from vllm.config import CacheConfig, LoRAConfig
30
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
31
from vllm.model_executor.layers.activation import get_act_fn
32
33
34
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
35
from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
from vllm.model_executor.layers.quantization import QuantizationConfig
37
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
39
    ParallelLMHead, VocabParallelEmbedding)
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
41
from vllm.model_executor.sampling_metadata import SamplingMetadata
42
from vllm.sequence import IntermediateTensors
43

44
45
46
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
47

48
49
50

class GPTBigCodeAttention(nn.Module):
    """Causal self-attention layer used inside every GPTBigCode block.

    A single fused projection (``c_attn``) produces Q, K and V. The attention
    computation is delegated to vLLM's ``Attention`` layer, and ``c_proj``
    maps the result back to ``hidden_size``. When ``config.multi_query`` is
    set, every query head shares a single K/V head.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        tp_world_size = get_tensor_model_parallel_world_size()
        total_num_heads = config.num_attention_heads
        # Query heads must divide evenly across tensor-parallel ranks.
        assert total_num_heads % tp_world_size == 0

        self.hidden_size = config.hidden_size
        self.tensor_model_parallel_world_size = tp_world_size
        self.num_heads = total_num_heads // tp_world_size
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        # Multi-query attention keeps exactly one K/V head; otherwise there
        # is one K/V head per query head.
        self.multi_query = config.multi_query
        total_num_kv_heads = 1 if self.multi_query else total_num_heads
        self.num_kv_heads = 1 if self.multi_query else self.num_heads
        self.kv_dim = self.head_dim * self.num_kv_heads

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and apply the output projection.

        Args:
            hidden_states: Input activations fed to ``c_attn``.
            kv_cache: This layer's KV cache, forwarded to ``attn``.
            attn_metadata: Attention metadata, forwarded to ``attn``.

        Returns:
            The output of ``c_proj`` applied to the attention result.
        """
        # The fused projection output is split along the last dim as
        # [q | k | v]; q covers this rank's share of the query heads.
        q_width = self.hidden_size // self.tensor_model_parallel_world_size
        fused_qkv, _ = self.c_attn(hidden_states)
        query, key, value = fused_qkv.split(
            [q_width, self.kv_dim, self.kv_dim],
            dim=-1,
        )
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.c_proj(context)
        return projected


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer of a GPTBigCode block.

    Computes ``c_proj(act(c_fc(x)))``. The input is expanded from
    ``config.hidden_size`` to ``intermediate_size`` and then projected back.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            width,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            width,
            bias=True,
            quant_config=quant_config,
        )
        # Activation is selected by name from the HF config.
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation and down-projection."""
        expanded, _ = self.c_fc(hidden_states)
        contracted, _ = self.c_proj(self.act(expanded))
        return contracted


class GPTBigCodeBlock(nn.Module):
    """A single GPTBigCode decoder layer.

    Uses pre-LayerNorm ordering. ``ln_1`` is applied before attention and
    ``ln_2`` before the MLP, and each sub-layer's output is added back to its
    input through a residual connection.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner
        if inner_dim is None:
            # Without an explicit n_inner, the MLP is 4x the hidden size.
            inner_dim = 4 * hidden_size
        eps = config.layer_norm_epsilon

        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
        self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = GPTBigMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention then MLP, each with a residual connection.

        Args:
            hidden_states: Layer input activations.
            kv_cache: This layer's KV cache, forwarded to the attention.
            attn_metadata: Attention metadata, forwarded to the attention.

        Returns:
            The layer output, with the same shape as ``hidden_states``.
        """
        attn_output = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # First residual: attention output plus the un-normalized input.
        post_attn = attn_output + hidden_states

        mlp_output = self.mlp(self.ln_2(post_attn))
        # Second residual: the post-attention stream plus the MLP output.
        return post_attn + mlp_output


191
@support_torch_compile
class GPTBigCodeModel(nn.Module):
    """GPTBigCode decoder stack without the language-modeling head.

    The first pipeline-parallel rank embeds tokens (``wte``) and learned
    absolute positions (``wpe``). Every rank runs its own slice of the
    layers in ``h``. Non-final ranks return their activations as
    ``IntermediateTensors``; the final rank applies ``ln_f``.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size

        # LoRA may add extra vocabulary entries: lora_extra_vocab_size for
        # each of max_loras adapters (at least one).
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = config.vocab_size + extra_vocab

        self.wte = VocabParallelEmbedding(
            self.vocab_size,
            self.embed_dim,
            org_num_embeddings=config.vocab_size,
        )
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        # make_layers builds only this pipeline rank's layers and reports
        # the [start_layer, end_layer) range it owns.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline rank's portion of the decoder.

        Args:
            input_ids: Token ids (read only on the first pipeline rank).
            position_ids: Position ids (read only on the first pipeline rank).
            kv_caches: KV caches for the layers owned by this rank, indexed
                relative to ``start_layer``.
            attn_metadata: Attention metadata passed to every layer.
            intermediate_tensors: Activations from the previous pipeline
                rank (read only on non-first ranks).

        Returns:
            ``IntermediateTensors`` on non-final ranks; otherwise the
            ``ln_f``-normalized hidden states.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            hidden_states = self.wte(input_ids) + self.wpe(position_ids)
        else:
            hidden_states = intermediate_tensors["hidden_states"]

        for layer_idx in range(self.start_layer, self.end_layer):
            cache_idx = layer_idx - self.start_layer
            hidden_states = self.h[layer_idx](hidden_states,
                                              kv_caches[cache_idx],
                                              attn_metadata)

        if not pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return self.ln_f(hidden_states)


251
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GPTBigCode decoder plus a language-modeling head.

    Wraps ``GPTBigCodeModel`` as ``self.transformer`` and adds the output
    projection (``lm_head``), logits processing and sampling. When
    ``config.tie_word_embeddings`` is set, ``lm_head`` is the very same
    module as ``transformer.wte``; otherwise it is a separate
    ``ParallelLMHead`` with its own weights.
    """

    packed_modules_mapping = {"c_attn": ["c_attn"]}

    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]

    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    embedding_padding_modules = []

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
                                           lora_config)
        if self.config.tie_word_embeddings:
            # Reuse the input embedding module as the output projection.
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(
                self.transformer.vocab_size,
                self.transformer.embed_dim,
                org_num_embeddings=self.config.vocab_size)

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder; the head is applied later in ``compute_logits``.

        Returns whatever ``self.transformer`` returns for this pipeline rank.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Args:
            weights: ``(name, tensor)`` pairs whose names are expected to
                match this module's parameter names.
        """
        # remove_duplicate=False keeps both names of a tied parameter
        # (e.g. ``transformer.wte.weight`` and ``lm_head.weight``).
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        tied_embeddings = self.config.tie_word_embeddings
        for name, loaded_weight in weights:
            if tied_embeddings and "lm_head.weight" in name:
                # With tied embeddings lm_head *is* transformer.wte and is
                # filled from the wte entry. An untied lm_head is a separate
                # ParallelLMHead and must be loaded here, otherwise it would
                # never receive the checkpoint's weights.
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if is_pp_missing_parameter(name, self):
                # Parameter belongs to a layer on another pipeline rank.
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
                # The fused c_attn parameter is fed the same checkpoint
                # scale once for each of its q, k and v shards.
                weight_loader(param, loaded_weight, 'q')
                weight_loader(param, loaded_weight, 'k')
                weight_loader(param, loaded_weight, 'v')
            else:
                weight_loader(param, loaded_weight)