gpt_bigcode.py 9.59 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
21
from typing import Iterable, List, Optional, Tuple
22
23
24
25
26

import torch
from torch import nn
from transformers import GPTBigCodeConfig

27
from vllm.attention import Attention, AttentionMetadata
28
from vllm.distributed import get_tensor_model_parallel_world_size
29
from vllm.model_executor.layers.activation import get_act_fn
30
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
from vllm.model_executor.layers.sampler import Sampler
36
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
38
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
39
from vllm.model_executor.sampling_metadata import SamplingMetadata
40
from vllm.sequence import SamplerOutput
41
42
43
44


class GPTBigCodeAttention(nn.Module):
    """Self-attention layer of GPTBigCode, sharded over tensor-parallel ranks.

    A fused ``c_attn`` projection produces Q, K and V for the heads owned by
    this rank, ``self.attn`` runs attention against the KV cache, and
    ``c_proj`` projects the result back to ``hidden_size``. When
    ``config.multi_query`` is set, the model has a single key/value head;
    every rank keeps ``num_kv_heads == 1`` regardless of the world size.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Query heads are divided evenly across tensor-parallel ranks.
        assert total_num_heads % self.tensor_model_parallel_world_size == 0
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            # Multi-query attention: one K/V head in total, and each rank
            # keeps that single head locally.
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads

        # Per-rank widths of the Q and K/V slices of the fused c_attn output.
        # Both are derived from the per-rank head counts so the split in
        # forward() stays consistent for Q as well as K/V.
        self.q_size = self.num_heads * self.head_dim
        self.kv_dim = self.head_dim * self.num_kv_heads

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and project back to the model width.

        Args:
            hidden_states: Layer input; its last dimension is split after
                the fused QKV projection.
            kv_cache: This layer's KV cache, passed through to ``self.attn``.
            attn_metadata: Attention metadata, passed through to
                ``self.attn``.

        Returns:
            The first output of ``c_proj`` (its second output is discarded).
        """
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_dim, self.kv_dim], dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer: ``c_fc`` -> activation -> ``c_proj``.

    ``c_fc`` widens from ``config.hidden_size`` to ``intermediate_size``
    (column-parallel). ``c_proj`` narrows back (row-parallel). The
    activation comes from ``config.activation_function``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        model_width = config.hidden_size

        self.c_fc = ColumnParallelLinear(model_width,
                                         intermediate_size,
                                         bias=True,
                                         linear_method=linear_method)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        model_width,
                                        bias=True,
                                        linear_method=linear_method)
        # linear_method may be None or lack a quant_config; fall back to None.
        self.act = get_act_fn(
            config.activation_function,
            getattr(linear_method, "quant_config", None),
            intermediate_size,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply expand -> activate -> project to ``hidden_states``."""
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        projected, _ = self.c_proj(activated)
        return projected


class GPTBigCodeBlock(nn.Module):
    """One transformer layer with pre-LayerNorm and residual connections.

    Computes ``x + attn(ln_1(x))`` and then ``h + mlp(ln_2(h))``.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_epsilon
        # The MLP width defaults to 4x the model width when n_inner is unset.
        if config.n_inner is None:
            mlp_width = 4 * width
        else:
            mlp_width = config.n_inner

        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.attn = GPTBigCodeAttention(config, linear_method)
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        self.mlp = GPTBigMLP(mlp_width, config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers, each wrapped in a residual."""
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # First residual connection around attention.
        hidden_states = attn_out + hidden_states

        mlp_out = self.mlp(self.ln_2(hidden_states))
        # Second residual connection around the MLP.
        return hidden_states + mlp_out


class GPTBigCodeModel(nn.Module):
    """GPTBigCode decoder stack: embeddings, transformer layers, final norm.

    Token embeddings (``wte``) and learned absolute position embeddings
    (``wpe``) are summed. The sum passes through ``config.num_hidden_layers``
    ``GPTBigCodeBlock`` layers and then the final LayerNorm ``ln_f``.
    Cross-attention configs are rejected.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size

        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPTBigCodeBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed the inputs, run every layer, and apply the final LayerNorm.

        Args:
            input_ids: Token ids, looked up in ``wte``.
            position_ids: Position indices, looked up in ``wpe``.
            kv_caches: One KV cache per layer, indexed by layer position.
                A shorter list raises ``IndexError`` rather than silently
                skipping layers.
            attn_metadata: Passed through unchanged to every layer.

        Returns:
            The normalized hidden states after the last layer.
        """
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)

        # Index kv_caches (rather than zip) so a too-short list still fails
        # loudly instead of truncating the layer stack.
        for layer_idx, layer in enumerate(self.h):
            hidden_states = layer(hidden_states, kv_caches[layer_idx],
                                  attn_metadata)

        return self.ln_f(hidden_states)


class GPTBigCodeForCausalLM(nn.Module):
    """GPTBigCode decoder with a language-modelling head tied to ``wte``.

    The output projection reuses the token-embedding weight, so no separate
    ``lm_head`` parameter is created and any checkpoint ``lm_head.weight``
    is ignored when loading.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = GPTBigCodeModel(config, linear_method)
        # Weight tying: logits are computed against the input embeddings.
        self.lm_head_weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the final hidden states from the decoder stack."""
        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits with the tied weight."""
        return self.logits_processor(self.lm_head_weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate next-token selection to the sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters by name.

        Names containing ``lm_head.weight`` are skipped because the head is
        tied to ``wte``. Names containing ``.attn.bias`` (the attention mask)
        are skipped too; ``c_attn.bias`` does not match that substring and is
        loaded normally. Any other name missing from the module raises
        ``KeyError``.
        """
        # remove_duplicate=False keeps every name under which a parameter is
        # registered (presumably including the tied lm_head_weight alias).
        params_by_name = dict(self.named_parameters(remove_duplicate=False))
        for param_name, tensor in weights:
            if "lm_head.weight" in param_name or ".attn.bias" in param_name:
                continue
            target = params_by_name[param_name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)