# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
22
23
24
25
26

import torch
from torch import nn
from transformers import GPTBigCodeConfig

27
from vllm.attention import Attention, AttentionMetadata
28
from vllm.config import CacheConfig, LoRAConfig
29
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
30
from vllm.model_executor.layers.activation import get_act_fn
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
from vllm.model_executor.layers.quantization import QuantizationConfig
36
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
38
    ParallelLMHead, VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import IntermediateTensors
42

43
44
45
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
46

47
48
49

class GPTBigCodeAttention(nn.Module):
    """Self-attention sub-layer of a GPTBigCode transformer block.

    The layer is made of a fused QKV projection (``c_attn``), the attention
    op itself (``attn``) and an output projection (``c_proj``).

    When ``config.multi_query`` is set, a single key/value head is used
    (``num_kv_heads == 1``); otherwise this rank uses one key/value head per
    local query head.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections and the attention op.

        Args:
            config: Model configuration; ``hidden_size``,
                ``num_attention_heads`` and ``multi_query`` are read here.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Query heads are divided evenly across tensor-parallel ranks.
        assert total_num_heads % self.tensor_model_parallel_world_size == 0
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            # Multi-query attention: one key/value head in total.
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        # Width of the K (and of the V) slice produced by c_attn on this rank.
        self.kv_dim = self.head_dim * self.num_kv_heads
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, run attention and apply the output projection.

        Args:
            hidden_states: Input activations fed to ``c_attn``.
            kv_cache: KV cache passed through to ``self.attn``.
            attn_metadata: Attention metadata passed through to ``self.attn``.

        Returns:
            The output of ``c_proj`` applied to the attention result.
        """
        # The second value returned by the parallel linear layers is unused.
        qkv, _ = self.c_attn(hidden_states)
        # Split the fused projection along the last dimension into the local
        # query slice followed by the key and value slices.
        q, k, v = qkv.split(
            [
                self.hidden_size // self.tensor_model_parallel_world_size,
                self.kv_dim, self.kv_dim
            ],
            dim=-1,
        )
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer: ``c_fc`` -> activation -> ``c_proj``."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the two projections and the activation.

        Args:
            intermediate_size: Output width of ``c_fc`` / input width of
                ``c_proj``.
            config: Model configuration; ``hidden_size`` and
                ``activation_function`` are read here.
            quant_config: Forwarded to the linear layers and ``get_act_fn``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``c_fc``, the activation and ``c_proj`` in sequence."""
        # The second value returned by the parallel linear layers is unused.
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPTBigCodeBlock(nn.Module):
    """One transformer block: LayerNorm + attention, then LayerNorm + MLP.

    Each sub-layer is applied to the normalized input and its output is added
    back to the un-normalized input (residual connection).
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the two LayerNorms, the attention layer and the MLP.

        Args:
            config: Model configuration; ``hidden_size``, ``n_inner`` and
                ``layer_norm_epsilon`` are read here.
            cache_config: Forwarded to ``GPTBigCodeAttention``.
            quant_config: Forwarded to the attention layer and the MLP.
        """
        super().__init__()
        hidden_size = config.hidden_size
        # Fall back to 4 * hidden_size when the config leaves n_inner unset.
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers with residual connections.

        Args:
            hidden_states: Block input.
            kv_cache: KV cache passed through to the attention layer.
            attn_metadata: Attention metadata passed through to the attention
                layer.

        Returns:
            The block output after both residual additions.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class GPTBigCodeModel(nn.Module):
    """GPTBigCode transformer body: embeddings, blocks and final LayerNorm.

    The blocks are created through ``make_layers``, which also returns the
    ``start_layer`` / ``end_layer`` range that ``forward`` iterates over.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ):
        """Build the embeddings, the transformer blocks and ``ln_f``.

        Args:
            config: Model configuration. Cross-attention is not supported.
            cache_config: Forwarded to every ``GPTBigCodeBlock``.
            quant_config: Forwarded to every ``GPTBigCodeBlock``.
            lora_config: When given, the token embedding is enlarged by
                ``lora_extra_vocab_size * (max_loras or 1)`` entries.
            prefix: Name prefix passed to ``make_layers`` as ``"{prefix}.h"``.
        """
        super().__init__()
        self.config = config
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.wte = VocabParallelEmbedding(self.vocab_size,
                                          self.embed_dim,
                                          org_num_embeddings=config.vocab_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        # The per-layer prefix supplied by make_layers is not needed because
        # GPTBigCodeBlock takes no prefix argument.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    self.embed_dim))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the layers in ``[start_layer, end_layer)``.

        Args:
            input_ids: Token ids; only read on the first pipeline rank.
            position_ids: Position ids; only read on the first pipeline rank.
            kv_caches: One cache per local layer, indexed from 0.
            attn_metadata: Attention metadata passed to every layer.
            intermediate_tensors: Carries ``"hidden_states"`` from the
                previous pipeline rank; must not be ``None`` on ranks other
                than the first.

        Returns:
            On the last pipeline rank, the hidden states after ``ln_f``;
            otherwise an ``IntermediateTensors`` holding ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            inputs_embeds = self.wte(input_ids)
            position_embeds = self.wpe(position_ids)
            hidden_states = inputs_embeds + position_embeds
        else:
            # Fail with a clear assertion instead of a TypeError from
            # subscripting None.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            # kv_caches is indexed relative to the first local layer.
            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


249
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GPTBigCode transformer with a language-model head, logits and sampler.

    When ``config.tie_word_embeddings`` is set the head is the very same
    module as the input token embedding; otherwise a separate
    ``ParallelLMHead`` is created.
    """

    packed_modules_mapping = {"c_attn": ["c_attn"]}

    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]

    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    embedding_padding_modules = []

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build the transformer, the LM head, the logits processor and sampler.

        Args:
            config: Model configuration.
            cache_config: Forwarded to ``GPTBigCodeModel``.
            quant_config: Forwarded to ``GPTBigCodeModel``.
            lora_config: Forwarded to ``GPTBigCodeModel``; when given, its
                ``lora_extra_vocab_size`` is added to the unpadded vocab size.
        """
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
                                           lora_config)
        if self.config.tie_word_embeddings:
            # Share the input embedding module as the output head.
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(
                self.transformer.vocab_size,
                self.transformer.embed_dim,
                org_num_embeddings=self.config.vocab_size)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the transformer and return its result unchanged."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the logits processor's output for ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Raises:
            KeyError: If a weight that is not skipped has no parameter of the
                same name in this model.
        """
        # remove_duplicate=False keeps aliased names (e.g. a tied lm_head)
        # present in the lookup table.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if ("lm_head.weight" in name
                    and self.config.tie_word_embeddings):
                # The head shares the token embedding module, so a separate
                # checkpoint copy is not loaded. An untied head has its own
                # parameter and must be loaded like any other weight.
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
                # The same scale tensor is loaded once per q/k/v shard id.
                weight_loader(param, loaded_weight, 'q')
                weight_loader(param, loaded_weight, 'k')
                weight_loader(param, loaded_weight, 'v')
            else:
                weight_loader(param, loaded_weight)