gpt_bigcode.py 9.6 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
21
from typing import Iterable, List, Optional, Tuple
22
23
24
25
26

import torch
from torch import nn
from transformers import GPTBigCodeConfig

27
from vllm.attention import Attention, AttentionMetadata
28
from vllm.distributed import get_tensor_model_parallel_world_size
29
from vllm.model_executor.layers.activation import get_act_fn
30
31
32
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
33
from vllm.model_executor.layers.logits_processor import LogitsProcessor
34
35
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
36
from vllm.model_executor.layers.sampler import Sampler
37
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import SamplerOutput
42
43
44
45


class GPTBigCodeAttention(nn.Module):
    """Self-attention sublayer of a GPTBigCode block.

    Query heads are divided evenly across tensor-parallel ranks. When
    ``config.multi_query`` is set, a single key/value head is used;
    otherwise each rank holds as many key/value heads as query heads.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        heads_total = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        tp_size = self.tensor_model_parallel_world_size

        # Query heads must split evenly across tensor-parallel ranks.
        assert heads_total % tp_size == 0
        self.num_heads = heads_total // tp_size
        self.head_dim = self.hidden_size // heads_total
        self.scale = self.head_dim**-0.5

        # Multi-query attention: one K/V head shared by all query heads.
        self.multi_query = config.multi_query
        kv_heads_total = 1 if self.multi_query else heads_total
        self.num_kv_heads = 1 if self.multi_query else self.num_heads
        # Width of the local K (and V) slice of the fused QKV output.
        self.kv_dim = self.head_dim * self.num_kv_heads

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            heads_total,
            kv_heads_total,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and project the result.

        The fused QKV projection is split along its last dimension into the
        local query slice (``hidden_size // tp_size`` wide) followed by the
        key and value slices (``kv_dim`` wide each).

        Args:
            hidden_states: Input activations for this layer.
            kv_cache: This layer's KV cache, passed to the attention backend.
            attn_metadata: Attention metadata for the current batch.

        Returns:
            The output of ``c_proj`` applied to the attention result.
        """
        qkv, _ = self.c_attn(hidden_states)
        q_width = self.hidden_size // self.tensor_model_parallel_world_size
        q, k, v = qkv.split([q_width, self.kv_dim, self.kv_dim], dim=-1)
        context = self.attn(q, k, v, kv_cache, attn_metadata)
        projected, _ = self.c_proj(context)
        return projected


class GPTBigMLP(nn.Module):
    """Feed-forward sublayer of a GPTBigCode block.

    Applies a column-parallel up-projection (``c_fc``), the configured
    activation, and a row-parallel down-projection (``c_proj``).
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        # Hand the quantization config straight to get_act_fn. The previous
        # code looked up a nested ``quant_config.quant_config`` attribute,
        # which (assuming QuantizationConfig defines no such attribute --
        # verify) always evaluated to None. That silently disabled
        # quantization-aware (scaled) activations.
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation and down-projection."""
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPTBigCodeBlock(nn.Module):
    """One transformer layer with pre-LayerNorm residual sublayers.

    Computes ``x = x + attn(ln_1(x))`` followed by ``x = x + mlp(ln_2(x))``.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # The MLP width defaults to 4x the hidden size when n_inner is unset.
        if config.n_inner is None:
            inner_dim = 4 * hidden_size
        else:
            inner_dim = config.n_inner
        eps = config.layer_norm_epsilon

        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
        self.attn = GPTBigCodeAttention(config, quant_config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = GPTBigMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sublayers, each with a residual add.

        Args:
            hidden_states: Input activations for this layer.
            kv_cache: This layer's KV cache, forwarded to attention.
            attn_metadata: Attention metadata for the current batch.

        Returns:
            The updated hidden states.
        """
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # First residual: attention output added to the block input.
        hidden_states = attn_out + hidden_states

        mlp_out = self.mlp(self.ln_2(hidden_states))
        # Second residual: MLP output added to the post-attention state.
        return hidden_states + mlp_out


class GPTBigCodeModel(nn.Module):
    """Decoder stack of GPTBigCode.

    Sums token embeddings (``wte``) and learned position embeddings
    (``wpe``), runs the stack of ``GPTBigCodeBlock`` layers, and applies a
    final LayerNorm (``ln_f``).
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        # Configs that enable cross-attention are rejected.
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        layers = [
            GPTBigCodeBlock(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.h = nn.ModuleList(layers)
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed the inputs, run every layer, and return normalized states.

        Args:
            input_ids: Token ids to embed.
            position_ids: Position ids for the learned position embedding.
            kv_caches: One KV cache per layer, indexed by layer position.
            attn_metadata: Attention metadata shared by all layers.

        Returns:
            Hidden states after the final LayerNorm.
        """
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)
        for layer_idx, block in enumerate(self.h):
            hidden_states = block(hidden_states, kv_caches[layer_idx],
                                  attn_metadata)
        return self.ln_f(hidden_states)


class GPTBigCodeForCausalLM(nn.Module):
    """GPTBigCode causal LM whose output head is tied to the token embedding.

    Logits are computed with ``transformer.wte.weight``, so no separate
    ``lm_head`` parameter is created.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPTBigCodeModel(config, quant_config)
        # Tie the output projection to the input embedding matrix.
        self.lm_head_weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states for the given inputs."""
        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the tied weight."""
        return self.logits_processor(self.lm_head_weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate token selection to the sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs whose names match this
                module's parameter names.
        """
        params_by_name = dict(self.named_parameters(remove_duplicate=False))
        for name, tensor in weights:
            # "lm_head.weight": redundant, the head is tied to wte above.
            # ".attn.bias": the attention mask, not a learned parameter.
            # NOTE: "c_attn.bias" does not contain ".attn.bias", so the
            # QKV bias is still loaded.
            if "lm_head.weight" in name or ".attn.bias" in name:
                continue
            target = params_by_name[name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)