# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
21
from typing import List, Optional, Tuple
22
23
24
25
26
27
28

import torch
from torch import nn
from transformers import GPTBigCodeConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
29
from vllm.model_executor.layers.attention import Attention
30
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
from vllm.model_executor.layers.sampler import Sampler
36
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
38
from vllm.model_executor.parallel_utils.parallel_state import (
39
    get_tensor_model_parallel_world_size)
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
42
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
43
from vllm.sequence import SamplerOutput
44
45
46
47
48
49

# Per-layer KV cache: a (key_cache, value_cache) pair of tensors, unpacked in
# GPTBigCodeAttention.forward and handed to the vLLM Attention layer.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTBigCodeAttention(nn.Module):
    """Self-attention for one GPTBigCode layer, sharded across TP ranks.

    When ``config.multi_query`` is set, a single key/value head is used
    (multi-query attention); otherwise every query head has its own
    key/value head.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the fused QKV projection, output projection and attention.

        Args:
            config: HuggingFace GPTBigCode configuration.
            linear_method: Optional (e.g. quantized) linear implementation
                forwarded to the parallel linear layers.

        Raises:
            ValueError: if ``num_attention_heads`` is not divisible by the
                tensor-parallel world size, or ``hidden_size`` is not
                divisible by ``num_attention_heads``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Explicit checks rather than ``assert`` so they are not stripped
        # when Python runs with -O.
        if total_num_heads % self.tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_attention_heads ({total_num_heads}) must be divisible "
                f"by the tensor parallel world size "
                f"({self.tensor_model_parallel_world_size}).")
        if self.hidden_size % total_num_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({total_num_heads}).")
        # Query heads owned by this rank.
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            # Multi-query attention: one key/value head in total.
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        # Per-rank widths of the query and key/value slices of the fused
        # QKV projection output (split in ``forward``).
        self.q_size = self.num_heads * self.head_dim
        self.kv_dim = self.head_dim * self.num_kv_heads
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            linear_method=linear_method,
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and project the result.

        Args:
            hidden_states: Layer input; the last dimension is assumed to be
                ``hidden_size`` (TODO confirm leading dims with caller).
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Batch metadata consumed by the Attention layer.

        Returns:
            Output of ``c_proj`` applied to the attention output.
        """
        qkv, _ = self.c_attn(hidden_states)
        # The fused projection output is laid out as [q | k | v].
        q, k, v = qkv.split([self.q_size, self.kv_dim, self.kv_dim], dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):
    """Feed-forward sublayer: ``c_fc`` -> activation -> ``c_proj``."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Create the up/down projections and the activation.

        Args:
            intermediate_size: Width of the inner (expanded) representation.
            config: HuggingFace GPTBigCode configuration; supplies
                ``hidden_size`` and ``activation_function``.
            linear_method: Optional linear implementation; its
                ``quant_config`` attribute (if any) is passed to
                ``get_act_fn``.
        """
        super().__init__()
        model_width = config.hidden_size
        # Up-projection (column-parallel) into the intermediate width.
        self.c_fc = ColumnParallelLinear(
            model_width,
            intermediate_size,
            bias=True,
            linear_method=linear_method,
        )
        # Down-projection (row-parallel) back to the model width.
        self.c_proj = RowParallelLinear(
            intermediate_size,
            model_width,
            bias=True,
            linear_method=linear_method,
        )
        self.act = get_act_fn(
            config.activation_function,
            getattr(linear_method, "quant_config", None),
            intermediate_size,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the two projections with the activation in between."""
        expanded, _ = self.c_fc(hidden_states)
        projected, _ = self.c_proj(self.act(expanded))
        return projected


class GPTBigCodeBlock(nn.Module):
    """One decoder layer: LayerNorm -> attention and LayerNorm -> MLP.

    Each sublayer's output is added back onto its input (residual).
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the two LayerNorms, the attention and the MLP sublayers.

        The MLP inner width is ``config.n_inner`` when set, otherwise
        ``4 * hidden_size``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        eps = config.layer_norm_epsilon
        inner_dim = config.n_inner
        if inner_dim is None:
            inner_dim = 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
        self.attn = GPTBigCodeAttention(config, linear_method)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = GPTBigMLP(inner_dim, config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sublayers, each with a residual add.

        Args:
            hidden_states: Layer input.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Batch metadata forwarded to attention.

        Returns:
            The layer output, same shape as ``hidden_states``.
        """
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        # Residual around attention.
        hidden_states = attn_out + hidden_states

        mlp_out = self.mlp(self.ln_2(hidden_states))
        # Residual around the MLP.
        return hidden_states + mlp_out


class GPTBigCodeModel(nn.Module):
    """GPTBigCode decoder stack without the LM head.

    Token embeddings (``wte``) plus learned absolute position embeddings
    (``wpe``), ``config.num_hidden_layers`` GPTBigCodeBlock layers, and a
    final LayerNorm (``ln_f``).
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build embeddings, decoder layers and the final LayerNorm.

        Raises:
            NotImplementedError: if ``config.add_cross_attention`` is set;
                cross-attention layers are not implemented here.
        """
        super().__init__()
        self.config = config
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if config.add_cross_attention:
            raise NotImplementedError(
                "GPTBigCode with add_cross_attention=True is not supported.")

        self.embed_dim = config.hidden_size

        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPTBigCodeBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed the inputs, run every decoder layer, and normalize.

        Args:
            input_ids: Token ids to embed with ``wte``.
            position_ids: Positions to embed with ``wpe``.
            kv_caches: One ``(key_cache, value_cache)`` pair per layer,
                indexed by layer position.
            input_metadata: Batch metadata forwarded to every layer.

        Returns:
            Hidden states after ``ln_f``.
        """
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Index kv_caches (rather than zip) so a too-short list still fails
        # loudly instead of silently skipping layers.
        for i, layer in enumerate(self.h):
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTBigCodeForCausalLM(nn.Module):
    """GPTBigCode decoder with an output projection tied to ``wte``."""

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the decoder stack, logits processor and sampler."""
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = GPTBigCodeModel(config, linear_method)
        # The output projection reuses the token-embedding matrix (weight
        # tying); the checkpoint's "lm_head.weight" is skipped on load.
        self.lm_head_weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return final hidden states; logits come from ``compute_logits``."""
        return self.transformer(input_ids, positions, kv_caches,
                                input_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits using the tied embedding weight."""
        return self.logits_processor(self.lm_head_weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` via the model's sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Copy HuggingFace checkpoint tensors into this model's parameters.

        Each tensor is loaded with its parameter's ``weight_loader`` if it
        defines one, otherwise with ``default_weight_loader``.

        Raises:
            KeyError: if the checkpoint contains a tensor name that has no
                matching parameter (other than the skipped names below).
        """
        # remove_duplicate=False keeps every registered name, including the
        # "lm_head_weight" alias of the shared embedding tensor.
        named_params = dict(self.named_parameters(remove_duplicate=False))
        checkpoint = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                               load_format, revision)
        for weight_name, tensor in checkpoint:
            # "lm_head.weight" is tied to wte; ".attn.bias" is the attention
            # mask. NOTE: "c_attn.bias" does not match and is still loaded.
            if "lm_head.weight" in weight_name or ".attn.bias" in weight_name:
                continue
            target = named_params[weight_name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)