gpt2.py 9.91 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
19
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
20
from typing import List, Optional
Woosuk Kwon's avatar
Woosuk Kwon committed
21
22
23
24
25

import torch
from torch import nn
from transformers import GPT2Config

26
from vllm.attention import Attention, AttentionMetadata
27
from vllm.distributed import get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
28
from vllm.model_executor.layers.activation import get_act_fn
29
30
31
32
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
33
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.model_executor.layers.sampler import Sampler
35
36
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
37
from vllm.model_executor.sampling_metadata import SamplingMetadata
38
39
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
40
from vllm.sequence import SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
41
42
43
44


class GPT2Attention(nn.Module):
    """GPT-2 multi-head self-attention, with heads split across
    tensor-parallel ranks.

    The fused projection ``c_attn`` and the output projection ``c_proj`` keep
    the HF module names so checkpoint tensors map onto them by name.
    """

    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Every rank must own the same number of heads. Raise rather than
        # assert so the check is not stripped under ``python -O``.
        if total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_attention_heads ({total_num_heads}) must be divisible "
                f"by the tensor parallel world size "
                f"({tensor_model_parallel_world_size}).")
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        # HF GPT-2 divides attention scores by sqrt(head_dim) only when
        # ``config.scale_attn_weights`` is true (the GPT2Config default).
        # Honour the flag instead of always scaling. The ``getattr`` default
        # keeps configs without the attribute on the original behaviour.
        if getattr(config, "scale_attn_weights", True):
            self.scale = self.head_dim**-0.5
        else:
            self.scale = 1.0

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )
        # Operates on this rank's share of heads only.
        self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input activations fed to ``c_attn``.
            kv_cache: Cache tensor for this layer, passed through to
                ``self.attn``.
            attn_metadata: Batch attention metadata, passed through to
                ``self.attn``.

        Returns:
            The output of ``c_proj`` applied to the attention result.
        """
        qkv, _ = self.c_attn(hidden_states)
        # The fused projection output is split into three equal parts along
        # the last dimension: query, key, value.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPT2MLP(nn.Module):
    """GPT-2 feed-forward sub-layer.

    Projects hidden_size -> intermediate_size (``c_fc``), applies the
    configured activation, then projects back to hidden_size (``c_proj``).
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        embed_dim = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            embed_dim,
            intermediate_size,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            embed_dim,
            bias=True,
            linear_method=linear_method,
        )
        # A linear_method may carry a quantization config for the activation.
        # Methods without one (or no method at all) yield None.
        self.act = get_act_fn(config.activation_function,
                              getattr(linear_method, "quant_config", None),
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation and down-projection."""
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        projected, _ = self.c_proj(activated)
        return projected


class GPT2Block(nn.Module):
    """One GPT-2 decoder layer.

    LayerNorm is applied before each sub-layer (attention, then MLP), and
    each sub-layer's output is added back onto its input.
    """

    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        embed_dim = config.hidden_size
        eps = config.layer_norm_epsilon
        # Without an explicit n_inner the MLP width is 4x the hidden size.
        mlp_width = config.n_inner
        if mlp_width is None:
            mlp_width = 4 * embed_dim

        self.ln_1 = nn.LayerNorm(embed_dim, eps=eps)
        self.attn = GPT2Attention(config, linear_method)
        self.ln_2 = nn.LayerNorm(embed_dim, eps=eps)
        self.mlp = GPT2MLP(mlp_width, config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and MLP sub-layers with residual connections.

        Args:
            hidden_states: Layer input activations.
            kv_cache: This layer's cache tensor, forwarded to attention.
            attn_metadata: Batch attention metadata, forwarded to attention.

        Returns:
            The layer output.
        """
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # First residual: attention output added to the layer input.
        hidden_states = attn_out + hidden_states

        mlp_out = self.mlp(self.ln_2(hidden_states))
        # Second residual: MLP output added to the post-attention stream.
        return hidden_states + mlp_out


class GPT2Model(nn.Module):
    """GPT-2 decoder stack without the LM head.

    Sums token embeddings (``wte``) and learned position embeddings
    (``wpe``), runs them through ``config.num_hidden_layers`` blocks, and
    applies a final LayerNorm (``ln_f``).
    """

    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        # These HF GPT-2 options are not implemented by this module. Raise
        # instead of ``assert`` so the guard survives ``python -O``;
        # otherwise an unsupported config would run and silently produce
        # wrong outputs.
        for flag in ("add_cross_attention",
                     "scale_attn_by_inverse_layer_idx",
                     "reorder_and_upcast_attn"):
            if getattr(config, flag):
                raise ValueError(
                    f"GPT-2 config option `{flag}` is not supported.")
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPT2Block(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute final hidden states for a batch of tokens.

        Args:
            input_ids: Token ids looked up in ``wte``.
            position_ids: Position indices looked up in ``wpe``.
            kv_caches: One cache tensor per block, indexed by block position.
            attn_metadata: Batch attention metadata shared by all blocks.

        Returns:
            Hidden states after the last block and ``ln_f``.
        """
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        for i, layer in enumerate(self.h):
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPT2LMHeadModel(nn.Module):
    """GPT-2 with a language-modeling head.

    The output projection reuses the token-embedding weight
    (``transformer.wte.weight``), so no separate ``lm_head`` parameter is
    loaded from checkpoints.
    """

    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = GPT2Model(config, linear_method)
        # Tied weights: the LM head is the token embedding matrix.
        self.lm_head_weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the final hidden states from ``self.transformer``.

        Logits are produced separately by ``compute_logits``.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to logits using the tied embedding weight."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Select next tokens from ``logits`` via ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None) -> None:
        """Copy HF GPT-2 checkpoint tensors into this model's parameters.

        Args:
            model_name_or_path: Checkpoint location, forwarded to
                ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            load_format: Forwarded to ``hf_model_weights_iterator``.
            revision: Forwarded to ``hf_model_weights_iterator``.

        Raises:
            KeyError: If a checkpoint tensor has no matching parameter.
        """
        # HF GPT-2 implements these projections as Conv1D, whose weight
        # layout is the transpose of what the linear layers here expect.
        conv1d_names = ("c_attn", "c_proj", "c_fc")
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]
            # Decide once and transpose at most once. Transposing once per
            # matching marker would silently undo the transpose for any name
            # that happened to contain two markers.
            # NOTE(zhuohan): this logic might break quantized models.
            if name.endswith(".weight") and any(
                    conv_name in name for conv_name in conv1d_names):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)