gpt_bigcode.py 9.91 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
21
from typing import List, Optional
22
23
24
25
26

import torch
from torch import nn
from transformers import GPTBigCodeConfig

27
from vllm.attention import Attention, AttentionMetadata
28
from vllm.model_executor.layers.activation import get_act_fn
29
30
31
32
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
33
from vllm.model_executor.layers.logits_processor import LogitsProcessor
34
from vllm.model_executor.layers.sampler import Sampler
35
36
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
37
from vllm.model_executor.parallel_utils.parallel_state import (
38
    get_tensor_model_parallel_world_size)
39
from vllm.model_executor.sampling_metadata import SamplingMetadata
40
41
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
42
from vllm.sequence import SamplerOutput
43
44
45
46


class GPTBigCodeAttention(nn.Module):
    """Self-attention layer of a GPTBigCode block.

    A fused ``c_attn`` projection produces Q, K and V in one matmul. The
    result is split into its Q/K/V slices for this tensor-parallel rank and
    fed to vLLM's ``Attention`` backend together with the KV cache. The
    output is then projected back with ``c_proj``.

    When ``config.multi_query`` is true, a single K/V head is used
    (``num_kv_heads == 1``) and shared by all query heads. Otherwise there is
    one K/V head per query head.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Query heads are split evenly across tensor-parallel ranks. Raise
        # explicitly instead of using ``assert`` so the check is not stripped
        # under ``python -O``. Otherwise the integer division below would
        # silently floor and build mis-shaped layers.
        if total_num_heads % self.tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_heads}) must be "
                f"divisible by the tensor parallel world size "
                f"({self.tensor_model_parallel_world_size}).")
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        # Softmax scaling factor 1/sqrt(head_dim), passed to the backend.
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            # Multi-query attention: one K/V head shared by every query head.
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        self.kv_dim = self.head_dim * self.num_kv_heads
        # Width of this rank's Q slice in the fused c_attn output. It is
        # constant for the module's lifetime, so it is computed once here
        # rather than on every forward pass.
        # NOTE(review): equals num_heads * head_dim only when hidden_size is
        # divisible by num_attention_heads — true for released configs.
        self.q_size = (self.hidden_size //
                       self.tensor_model_parallel_world_size)
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            linear_method=linear_method,
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Input activations. The last dimension is
                ``hidden_size`` (the input size of ``c_attn``).
            kv_cache: KV-cache tensor for this layer, passed through to the
                attention backend.
            attn_metadata: Batch metadata consumed by the attention backend.

        Returns:
            The attention output after the ``c_proj`` projection.
        """
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_dim, self.kv_dim], dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):
    """Feed-forward sub-layer of a GPTBigCode block.

    Projects from ``config.hidden_size`` up to ``intermediate_size`` with
    ``c_fc``, applies ``config.activation_function``, and projects back down
    to ``config.hidden_size`` with ``c_proj``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        model_width = config.hidden_size
        # Up-projection: model_width -> intermediate_size.
        self.c_fc = ColumnParallelLinear(model_width,
                                         intermediate_size,
                                         bias=True,
                                         linear_method=linear_method)
        # Down-projection: intermediate_size -> model_width.
        self.c_proj = RowParallelLinear(intermediate_size,
                                        model_width,
                                        bias=True,
                                        linear_method=linear_method)
        # linear_method may be None; then no quantization config is passed.
        act_quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, act_quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``c_proj(act(c_fc(hidden_states)))``."""
        expanded, _fc_bias = self.c_fc(hidden_states)
        activated = self.act(expanded)
        projected, _proj_bias = self.c_proj(activated)
        return projected


class GPTBigCodeBlock(nn.Module):
    """One GPTBigCode transformer layer.

    Pre-LayerNorm layout: ``x + attn(ln_1(x))`` followed by
    ``x + mlp(ln_2(x))``.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        width = config.hidden_size
        # config.n_inner, when set, overrides the default 4x MLP expansion.
        if config.n_inner is None:
            mlp_width = 4 * width
        else:
            mlp_width = config.n_inner

        self.ln_1 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(config, linear_method)
        self.ln_2 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(mlp_width, config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers, each with a residual add."""
        # Attention sub-layer plus its residual connection.
        normed = self.ln_1(hidden_states)
        attended = self.attn(
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        after_attn = attended + hidden_states

        # Feed-forward sub-layer plus its residual connection.
        mlp_out = self.mlp(self.ln_2(after_attn))
        return after_attn + mlp_out


class GPTBigCodeModel(nn.Module):
    """GPTBigCode decoder stack without the LM head.

    Sums token embeddings (``wte``) and learned absolute position embeddings
    (``wpe``). The result runs through ``config.num_hidden_layers``
    ``GPTBigCodeBlock`` layers and then the final LayerNorm ``ln_f``.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        # Only decoder-only self-attention is implemented. Raise explicitly
        # rather than ``assert`` so the check survives ``python -O``.
        if config.add_cross_attention:
            raise ValueError(
                "GPTBigCode with cross-attention (add_cross_attention=True) "
                "is not supported.")

        self.embed_dim = config.hidden_size

        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPTBigCodeBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute final hidden states for ``input_ids``.

        Args:
            input_ids: Token ids to embed with ``wte``.
            position_ids: Position indices for ``wpe``. They must be below
                ``config.max_position_embeddings`` (the size of ``wpe``).
            kv_caches: One KV-cache tensor per layer. Layer ``i`` uses
                ``kv_caches[i]``, so at least ``len(self.h)`` entries are
                required.
            attn_metadata: Batch metadata passed unchanged to every layer.

        Returns:
            Hidden states after ``ln_f``.
        """
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)

        for i, layer in enumerate(self.h):
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        return self.ln_f(hidden_states)


class GPTBigCodeForCausalLM(nn.Module):
    """GPTBigCode causal language model for vLLM.

    ``forward`` returns the decoder's final hidden states. ``compute_logits``
    turns them into logits using the token-embedding matrix as the output
    projection (tied weights). ``sample`` picks next tokens from those
    logits.
    """

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = GPTBigCodeModel(config, linear_method)
        # Tie the output projection to the input embedding. A checkpoint's
        # own lm_head.weight is therefore ignored in load_weights.
        self.lm_head_weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the transformer's final hidden states (no logits)."""
        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project ``hidden_states`` to vocabulary logits via the tied head."""
        return self.logits_processor(self.lm_head_weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Select next tokens from ``logits`` using the configured sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Copy HuggingFace checkpoint tensors into this model's parameters.

        Each tensor is routed through the parameter's ``weight_loader`` when
        one is attached. Otherwise ``default_weight_loader`` is used. A
        checkpoint name with no matching parameter raises ``KeyError``.
        """
        # remove_duplicate=False keeps aliased parameters (such as the tied
        # lm_head_weight) under every name they are registered with.
        named_params = dict(self.named_parameters(remove_duplicate=False))
        checkpoint = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                               load_format, revision)
        for name, loaded_weight in checkpoint:
            # "lm_head.weight": the head is tied to wte, so nothing to load.
            # ".attn.bias": the attention mask, not a model parameter.
            # "c_attn.bias" is not matched here because "attn.bias" is
            # preceded by "_" rather than "." in that name.
            if "lm_head.weight" in name or ".attn.bias" in name:
                continue
            target = named_params[name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, loaded_weight)