# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
20
from typing import Iterable, List, Optional, Set, Tuple, Union
21
22
23
24
25

import torch
from torch import nn
from transformers import GPTBigCodeConfig

26
from vllm.attention import Attention, AttentionMetadata
27
from vllm.compilation.decorators import support_torch_compile
28
from vllm.config import CacheConfig, VllmConfig
29
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
30
from vllm.model_executor.layers.activation import get_act_fn
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
36
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
38
    ParallelLMHead, VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import IntermediateTensors
42

43
44
45
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
46

47
48
49

class GPTBigCodeAttention(nn.Module):
    """Self-attention sub-layer of a GPTBigCode block.

    Query/key/value come from one fused ``c_attn`` projection. When
    ``config.multi_query`` is set, a single key/value head is shared by all
    query heads; otherwise there are as many key/value heads as query heads.
    Query heads are split evenly across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """
        Args:
            config: HF GPTBigCode config (reads ``hidden_size``,
                ``num_attention_heads`` and ``multi_query``).
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Optional quantization config for the linear layers
                and the attention op.
            prefix: Module path of this layer; used to name ``c_attn``,
                ``c_proj`` and the inner ``attn`` op.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % self.tensor_model_parallel_world_size == 0
        # Query heads owned by this tensor-parallel rank.
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            # Multi-query attention: one K/V head in total, and this rank
            # keeps exactly one K/V head.
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        # q is sized from num_heads * head_dim (the same quantities the
        # QKV projection is built from) rather than hidden_size // tp_size,
        # so the split stays consistent with the projection's output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_dim = self.head_dim * self.num_kv_heads
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and project back to model width.

        The fused ``c_attn`` output is split along its last dimension into
        q (``q_size`` wide) and k, v (``kv_dim`` wide each).
        """
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_dim, self.kv_dim], dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer of a GPTBigCode block.

    ``c_fc`` (column-parallel) maps ``hidden_size -> intermediate_size``,
    ``config.activation_function`` is applied, and ``c_proj``
    (row-parallel) maps back to ``hidden_size``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """
        Args:
            intermediate_size: Width of the inner FFN layer.
            config: HF GPTBigCode config (reads ``hidden_size`` and
                ``activation_function``).
            quant_config: Optional quantization config for both linears.
            prefix: Module path of this MLP, used to name ``c_fc`` and
                ``c_proj``. Defaults to ``""`` so existing callers keep
                working.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``c_proj(act(c_fc(hidden_states)))``."""
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPTBigCodeBlock(nn.Module):
    """One transformer layer with LayerNorm applied before each sub-layer.

    Computes ``x = x + attn(ln_1(x))`` followed by ``x = x + mlp(ln_2(x))``.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """
        Args:
            config: HF GPTBigCode config.
            cache_config: KV-cache configuration forwarded to the attention.
            quant_config: Optional quantization config for sub-layers.
            prefix: Module path of this block (e.g. ``"...h.3"``); the
                attention and MLP are named ``{prefix}.attn`` and
                ``{prefix}.mlp``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        # FFN width falls back to 4 * hidden_size when n_inner is unset.
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(config,
                                        cache_config,
                                        quant_config,
                                        prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(inner_dim,
                             config,
                             quant_config,
                             prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers, each with a residual add."""
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


195
@support_torch_compile
class GPTBigCodeModel(nn.Module):
    """Token + position embeddings, a stack of GPTBigCode blocks, final LN.

    Under pipeline parallelism this rank only runs layers
    ``[start_layer, end_layer)``. The first rank embeds the inputs, later
    ranks resume from ``intermediate_tensors``, and every rank except the
    last returns its activations wrapped in ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = hf_config
        assert not hf_config.add_cross_attention

        self.embed_dim = hf_config.hidden_size

        # Extra embedding rows reserved for LoRA-added tokens.
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = hf_config.vocab_size + extra_vocab
        self.wte = VocabParallelEmbedding(
            self.vocab_size,
            self.embed_dim,
            org_num_embeddings=hf_config.vocab_size,
        )
        # Learned absolute position embeddings.
        self.wpe = nn.Embedding(hf_config.max_position_embeddings,
                                self.embed_dim)

        # NOTE(review): the callback's parameter must stay named `prefix`;
        # presumably make_layers passes it by keyword — verify before
        # renaming.
        def build_block(prefix):
            return GPTBigCodeBlock(hf_config,
                                   cache_config,
                                   quant_config,
                                   prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.num_hidden_layers,
            build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim,
                                 eps=hf_config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        ``kv_caches`` is indexed relative to ``start_layer``. The result is
        the ``ln_f``-normalized hidden states on the last pipeline rank and
        ``IntermediateTensors`` on every other rank.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            hidden_states = inputs_embeds + self.wpe(position_ids)

        first = self.start_layer
        for layer_idx in range(first, self.end_layer):
            hidden_states = self.h[layer_idx](hidden_states,
                                              kv_caches[layer_idx - first],
                                              attn_metadata)

        if pp_group.is_last_rank:
            return self.ln_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


259
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GPTBigCode decoder with a language-modeling head.

    Wraps :class:`GPTBigCodeModel` (registered as ``transformer``) and adds
    the LM head, logits computation, sampling and checkpoint loading.
    """

    packed_modules_mapping = {"c_attn": ["c_attn"]}

    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]

    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        # The backbone is registered as `self.transformer`, so its parameters
        # are named "transformer.*". Give it that same prefix so submodule
        # prefixes (e.g. "transformer.h.0.attn.attn") line up with parameter
        # names instead of starting with a bare "." when `prefix` is empty.
        transformer_prefix = (f"{prefix}.transformer"
                              if prefix else "transformer")
        self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
                                           prefix=transformer_prefix)
        if self.config.tie_word_embeddings:
            # Output projection shares its weight with the input embedding.
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(
                self.transformer.vocab_size,
                self.transformer.embed_dim,
                org_num_embeddings=self.config.vocab_size)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone's ``wte``."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its hidden states (or PP tensors)."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(parameter_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                # Tied head reuses transformer.wte. An untied head is a
                # separate ParallelLMHead and must be loaded, otherwise it
                # keeps its uninitialized values.
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
                # A single fused-QKV scale is replicated into the q, k and v
                # shards.
                weight_loader(param, loaded_weight, 'q')
                weight_loader(param, loaded_weight, 'k')
                weight_loader(param, loaded_weight, 'v')
            else:
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params