# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
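
For example (illustrative values only): two prompts of lengths 3 and 2 are
concatenated into

    input_ids = tensor([5, 9, 2, 7, 1])   # shape (5,)
    positions = tensor([0, 1, 2, 0, 1])

and InputMetadata carries the per-sequence lengths so that each sequence
attends only to its own tokens.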
"""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTBigCodeConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
    convert_pyslice_to_tensor, hf_model_weights_iterator,
    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTBigCodeAttention(nn.Module):

    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % self.tensor_model_parallel_world_size == 0
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            self.num_kv_heads = 1
            self.kv_dim = self.head_dim
            self.c_attn_q = ColumnParallelLinear(self.hidden_size,
                                                 self.hidden_size,
                                                 bias=True,
                                                 gather_output=False,
                                                 perform_initialization=False)
            self.c_attn_kv = nn.Linear(self.hidden_size,
                                       2 * self.kv_dim,
                                       bias=True)
        else:
            self.num_kv_heads = self.num_heads
            self.kv_dim = self.num_kv_heads * self.head_dim
            self.c_attn = ColumnParallelLinear(self.hidden_size,
                                               self.hidden_size +
                                               2 * self.kv_dim,
                                               bias=True,
                                               gather_output=False,
                                               perform_initialization=False)
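
        # Shape sketch for the multi-query path (illustrative numbers only,
        # e.g. a StarCoder-like config: hidden_size=6144, 48 heads,
        # head_dim=128, tensor parallel world size 2):
        #   c_attn_q : 6144 -> 6144 in total, 3072 per rank (24 query heads)
        #   c_attn_kv: 6144 -> 2 * 128 on every rank (one shared K/V head,
        #              replicated rather than sharded)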

        self.c_proj = RowParallelLinear(self.hidden_size,
                                        self.hidden_size,
                                        bias=True,
                                        input_is_parallel=True,
                                        perform_initialization=False)
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scale,
                                   num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        if self.multi_query:
            q, _ = self.c_attn_q(hidden_states)
            kv = self.c_attn_kv(hidden_states)
            k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
        else:
            qkv, _ = self.c_attn(hidden_states)
            q, k, v = qkv.split([
                self.hidden_size // self.tensor_model_parallel_world_size,
                self.kv_dim, self.kv_dim
            ],
                                dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata, cache_event)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(hidden_size,
                                         intermediate_size,
                                         bias=True,
                                         gather_output=False,
                                         perform_initialization=False)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=True,
                                        input_is_parallel=True,
                                        perform_initialization=False)
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPTBigCodeBlock(nn.Module):

    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(inner_dim, config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class GPTBigCodeModel(nn.Module):

    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size

        # Optimization: pad the vocab size up to the next multiple of 64
        # (e.g., GPT-2's 50257 becomes 50304). This improves performance since
        # GPUs are faster when the dimension is divisible by 64. In addition,
        # it allows us to shard the embedding layer across 2, 4, 8, or more
        # GPUs.
        vocab_size = ((config.vocab_size + 63) // 64) * 64
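        # For example (illustrative): with the padded size 50304 = 64 * 786 and
        # 4-way tensor parallelism, each rank holds 12576 rows of wte.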
        self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList(
            [GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        for i in range(len(self.h)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
                                  cache_event)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTBigCodeForCausalLM(nn.Module):

    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        self.config = config
        self.transformer = GPTBigCodeModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        #                parallelism
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
    _row_parallel_weights = ["c_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue

            if not name.startswith("transformer."):
                name = "transformer." + name

            # For the fused QKV linear layer, manually shard the weights.
            if "c_attn" in name:
                # The fused QKV weight of GPTBigCode has the shape of
                # [hidden_size + 2 * total_kv_size, hidden_size], where
                # total_kv_size is head_size for multi-query attention and
                # num_heads * head_size otherwise.
                # When tensor parallelism is used, we shard the query weights
                # along the head dimension.
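                # Worked example (illustrative, multi-query case): with
                # hidden_size=6144, 48 heads, head_dim=128 and world size 2,
                # c_attn.weight has shape [6144 + 2 * 128, 6144]; rank 0 keeps
                # query rows 0:3072, rank 1 keeps rows 3072:6144, and the
                # 2 * 128 K/V rows are replicated on both ranks.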
                total_num_heads = self.config.num_attention_heads
                total_num_kv_heads = (1 if self.config.multi_query else
                                      total_num_heads)
                hidden_size = self.config.hidden_size
                head_size = hidden_size // total_num_heads
                total_kv_size = head_size * total_num_kv_heads
                num_heads = total_num_heads // tensor_model_parallel_world_size
                head_start = tensor_model_parallel_rank * num_heads
                head_end = (tensor_model_parallel_rank + 1) * num_heads

                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                wq, wk, wv = torch.split(
                    loaded_weight, [hidden_size, total_kv_size, total_kv_size],
                    dim=0)

                wq = wq[head_size * head_start:head_size * head_end]
                if not self.config.multi_query:
                    # Split the heads when using normal multi-head attention
                    wk = wk[head_size * head_start:head_size * head_end]
                    wv = wv[head_size * head_start:head_size * head_end]
                    loaded_weight = torch.cat([wq, wk, wv], dim=0)
                else:
                    # For multi-query attention, we split the query
                    # but replicate the key and value.
                    loaded_weight_q = wq
                    loaded_weight_kv = torch.cat([wk, wv], dim=0)
                    q_weight_name = name.replace("c_attn", "c_attn_q")
                    kv_weight_name = name.replace("c_attn", "c_attn_kv")
                    load_tensor_parallel_weights(state_dict[q_weight_name],
                                                 loaded_weight_q,
                                                 q_weight_name,
                                                 self._column_parallel_weights,
                                                 self._row_parallel_weights,
                                                 tensor_model_parallel_rank)
                    load_tensor_parallel_weights(state_dict[kv_weight_name],
                                                 loaded_weight_kv,
                                                 kv_weight_name,
                                                 self._column_parallel_weights,
                                                 self._row_parallel_weights,
                                                 tensor_model_parallel_rank)
                    continue

            param = state_dict[name]

            if name == "transformer.wte.weight":
                load_padded_tensor_parallel_vocab(param, loaded_weight,
                                                  tensor_model_parallel_rank)
                continue

            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)