# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
import numpy as np
from transformers import GPTBigCodeConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs

# Per-layer attention cache; the attention layer unpacks it as
# (key_cache, value_cache).
KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTBigCodeAttention(nn.Module):
    """Self-attention for one GPTBigCode layer, split across tensor-parallel ranks.

    Each rank owns ``num_attention_heads / world_size`` heads. A single fused
    projection (``c_attn``) produces q, k and v for those heads; the result is
    run through paged attention and projected back by ``c_proj``.
    """

    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        heads_total = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()

        # Heads must split evenly over the tensor-parallel ranks.
        assert heads_total % tp_size == 0
        self.num_heads = heads_total // tp_size
        self.head_dim = self.hidden_size // heads_total
        self.scale = self.head_dim**-0.5

        # Fused q/k/v projection; output stays sharded per rank.
        self.c_attn = ColumnParallelLinear(
            self.hidden_size,
            3 * self.hidden_size,
            bias=True,
            gather_output=False,
            perform_initialization=False,
        )
        # Output projection consuming the rank-local attention result.
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` using the paged KV cache.

        Args:
            hidden_states: Layer input (already layer-normed by the caller).
            kv_cache: ``(key_cache, value_cache)`` for this layer.
            input_metadata: Batch metadata consumed by ``PagedAttention``.
            cache_event: Optional CUDA event forwarded to ``PagedAttention``.

        Returns:
            The output of ``c_proj`` applied to the attention result.
        """
        fused, _ = self.c_attn(hidden_states)
        query, key, value = fused.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        context = self.attn(query, key, value, key_cache, value_cache,
                            input_metadata, cache_event)
        projected, _ = self.c_proj(context)
        return projected


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer: ``c_fc`` -> activation -> ``c_proj``.

    ``c_fc`` is column-parallel and ``c_proj`` row-parallel, so the
    intermediate activation stays sharded across tensor-parallel ranks.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
    ):
        super().__init__()
        embed_dim = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            embed_dim,
            intermediate_size,
            bias=True,
            gather_output=False,
            perform_initialization=False,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            embed_dim,
            bias=True,
            input_is_parallel=True,
            perform_initialization=False,
        )
        # Activation selected by name from the HF config.
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the two projections with the activation in between."""
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        output, _ = self.c_proj(activated)
        return output


class GPTBigCodeBlock(nn.Module):
    """One decoder layer: LayerNorm -> attention -> residual, then
    LayerNorm -> MLP -> residual."""

    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        hidden_size = config.hidden_size
        # MLP width: explicit n_inner if configured, otherwise 4x hidden.
        if config.n_inner is None:
            inner_dim = 4 * hidden_size
        else:
            inner_dim = config.n_inner
        eps = config.layer_norm_epsilon

        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
        self.attn = GPTBigCodeAttention(config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = GPTBigMLP(inner_dim, config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention and MLP sub-layers, each wrapped in a residual add."""
        attended = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        # First residual: attention output plus the un-normalized input.
        hidden_states = attended + hidden_states

        mlp_out = self.mlp(self.ln_2(hidden_states))
        # Second residual around the MLP.
        return hidden_states + mlp_out


class GPTBigCodeModel(nn.Module):
    """GPTBigCode decoder stack: token + position embeddings, N blocks, final
    LayerNorm. Returns hidden states (no LM head)."""

    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        self.config = config
        # Only decoder-only checkpoints are supported.
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size

        # Round the vocabulary up to a multiple of 64: GPU kernels are faster
        # on such sizes, and it lets the embedding be split evenly across 2, 4,
        # 8, ... tensor-parallel ranks.
        multiple = 64
        padded_vocab = ((config.vocab_size + multiple - 1) // multiple) * multiple
        self.wte = VocabParallelEmbedding(padded_vocab, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed the tokens, run every block, and apply the final LayerNorm.

        ``kv_caches`` (and ``cache_events`` when given) are indexed by layer.
        """
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)

        for idx, block in enumerate(self.h):
            event = None if cache_events is None else cache_events[idx]
            hidden_states = block(hidden_states, kv_caches[idx],
                                  input_metadata, event)

        return self.ln_f(hidden_states)


class GPTBigCodeForCausalLM(nn.Module):
    """GPTBigCode decoder with a weight-tied LM head and a token sampler."""

    def __init__(self, config: GPTBigCodeConfig):
        super().__init__()
        self.config = config
        self.transformer = GPTBigCodeModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        #                parallelism
        # The output projection reuses the token-embedding matrix (tied
        # weights), which is why load_weights skips "lm_head.weight".
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        """Run the decoder on a flattened batch and sample the next tokens.

        Args:
            input_ids: Token ids of the flattened batch.
            positions: Position ids fed to the position embedding.
            kv_caches: One ``(key_cache, value_cache)`` pair per layer.
            input_metadata: Batch metadata forwarded to attention and sampler.
            cache_events: Optional per-layer CUDA events, or ``None``.

        Returns:
            The ``Sampler`` output for the final hidden states, using the tied
            embedding matrix as the LM head.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
    _row_parallel_weights = ["c_proj.weight"]

    @staticmethod
    def _expand_mqa_to_mha(qkv: torch.Tensor, n_head: int,
                           head_dim: int) -> torch.Tensor:
        """Replicate the shared K/V head of a multi-query ``c_attn`` tensor.

        The checkpoint stores ``n_head`` query heads followed by ONE key head
        and ONE value head along dim 0, i.e. shape
        ``((n_head + 2) * head_dim, ...)``. The fused ``c_attn`` projection of
        this model produces ``3 * n_head * head_dim`` outputs, so k and v are
        tiled ``n_head`` times, giving ``(3 * n_head * head_dim, ...)``. Works
        for 2-D weights and 1-D biases alike.

        Implemented with torch ops (``Tensor.repeat`` tiles exactly like
        ``np.tile``) so it also handles dtypes NumPy cannot represent, such
        as bfloat16.

        TODO: this function is no longer needed once vLLM supports MQA.

        Raises:
            ValueError: if dim 0 of ``qkv`` is not ``(n_head + 2) * head_dim``.
        """
        q_rows = n_head * head_dim
        expected_rows = q_rows + 2 * head_dim
        if qkv.shape[0] != expected_rows:
            raise ValueError(
                f"Expected a multi-query c_attn tensor with {expected_rows} "
                f"rows along dim 0, got shape {tuple(qkv.shape)}.")
        q = qkv[:q_rows]
        k = qkv[q_rows:q_rows + head_dim]
        v = qkv[q_rows + head_dim:]
        # Repeat n_head times along dim 0; leave any trailing dim unchanged.
        reps = (n_head, ) + (1, ) * (qkv.dim() - 1)
        return torch.cat((q, k.repeat(*reps), v.repeat(*reps)), dim=0)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        """Load HuggingFace GPTBigCode weights into this model's parameters.

        For each checkpoint tensor from ``hf_model_weights_iterator``:
        - ``lm_head.weight`` is skipped (the head is tied to ``wte``), as are
          names containing ``.attn.bias`` (attention mask, not a parameter).
        - Names are prefixed with ``transformer.`` when missing, *before*
          looking them up in ``state_dict``.
        - ``transformer.wte.weight`` is padded with extra rows up to this
          model's padded vocabulary size.
        - ``c_attn`` weights/biases are expanded from multi-query to
          multi-head layout and sliced to this rank's heads.
        - Everything is then handed to ``load_tensor_parallel_weights``.

        Raises:
            KeyError: if a (normalized) checkpoint name has no matching
                parameter.
            ValueError: for an unexpected ``c_attn`` parameter name or a
                ``c_attn`` tensor that is not in multi-query layout.
        """
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        # Head geometry used to shard the fused QKV projection; loop-invariant.
        total_num_heads = self.config.num_attention_heads
        hidden_size = self.config.hidden_size
        head_size = hidden_size // total_num_heads
        num_heads = total_num_heads // tensor_model_parallel_world_size
        head_start = tensor_model_parallel_rank * num_heads
        head_end = head_start + num_heads

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "lm_head.weight" in name:
                # The LM head is tied to the token embedding (lm_head_weight).
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue

            # Normalize the name first; looking it up before this step would
            # raise KeyError for checkpoints whose names lack the prefix.
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = state_dict[name]

            if name == "transformer.wte.weight":
                # Consider padding in the vocab size: param is this rank's
                # shard, so multiply by the world size for the full size.
                padded_vocab_size = (param.shape[0] *
                                     tensor_model_parallel_world_size)
                num_extra_rows = padded_vocab_size - self.config.vocab_size
                extra_rows = loaded_weight.new_empty(
                    (num_extra_rows, loaded_weight.shape[1]))
                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

            # For the fused QKV linear layer, manually shard the weights.
            if "c_attn" in name:
                # After MQA->MHA expansion the fused QKV has shape
                # [3 * num_heads * head_size, (hidden_size)], laid out as
                # [q heads, k heads, v heads]. With tensor parallelism we keep
                # only this rank's heads from each of q, k and v.
                if name.endswith(".weight"):
                    per_head_shape = (head_size, hidden_size)
                elif name.endswith(".bias"):
                    per_head_shape = (head_size, )
                else:
                    raise ValueError(f"Unexpected parameter name {name}")

                loaded_weight = self._expand_mqa_to_mha(
                    loaded_weight, n_head=total_num_heads, head_dim=head_size)
                loaded_weight = loaded_weight.view(3, total_num_heads,
                                                   *per_head_shape)
                loaded_weight = loaded_weight[:, head_start:head_end]
                # Flatten back to (-1, hidden_size) for weights, (-1,) for
                # biases.
                loaded_weight = loaded_weight.reshape(-1, *per_head_shape[1:])

            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)