# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPT2Config

from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput


class GPT2Attention(nn.Module):
    """Causal self-attention sub-layer of a GPT-2 block.

    The attention heads are split evenly across tensor-parallel ranks.
    ``c_attn`` produces this rank's fused Q/K/V projection, ``attn`` runs the
    vLLM ``Attention`` layer (which receives the KV cache), and ``c_proj``
    (a ``RowParallelLinear``) maps the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Every tensor-parallel rank must own a whole number of heads.
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        # HF GPT-2 divides attention scores by sqrt(head_dim) only when
        # `config.scale_attn_weights` is true (its default); otherwise the
        # raw dot products are used. Honor the flag so such checkpoints
        # reproduce HF outputs. Configs lacking the attribute keep the
        # previous (scaled) behavior.
        if getattr(config, "scale_attn_weights", True):
            self.scale = self.head_dim**-0.5
        else:
            self.scale = 1.0

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )
        self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states`` and project the result.

        Args:
            hidden_states: input activations; presumably ``hidden_size`` is
                the last dimension (TODO confirm exact shape).
            kv_cache: this layer's KV cache, passed through to ``Attention``.
            attn_metadata: batch metadata, passed through to ``Attention``.

        Returns:
            The output of ``c_proj`` applied to the attention output.
        """
        qkv, _ = self.c_attn(hidden_states)
        # c_attn is built with one head count shared by Q, K and V, so the
        # fused output splits into three equal chunks on the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPT2MLP(nn.Module):
    """Feed-forward sub-layer of a GPT-2 block.

    Expands ``hidden_size`` features to ``intermediate_size`` with the
    column-parallel ``c_fc``, applies the activation named by
    ``config.activation_function``, and maps back to ``hidden_size`` with the
    row-parallel ``c_proj``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        d_model = config.hidden_size
        # Up-projection: d_model -> intermediate_size.
        self.c_fc = ColumnParallelLinear(d_model,
                                         intermediate_size,
                                         bias=True,
                                         linear_method=linear_method)
        # Down-projection: intermediate_size -> d_model.
        self.c_proj = RowParallelLinear(intermediate_size,
                                        d_model,
                                        bias=True,
                                        linear_method=linear_method)
        # When linear_method is None or has no `quant_config` attribute, the
        # activation factory receives None for its quantization config.
        self.act = get_act_fn(config.activation_function,
                              getattr(linear_method, "quant_config", None),
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a pair; only the first element
        # (the output tensor) is used.
        expanded, _ = self.c_fc(hidden_states)
        projected, _ = self.c_proj(self.act(expanded))
        return projected


class GPT2Block(nn.Module):
    """One GPT-2 transformer layer: LayerNorm -> attention -> residual add,
    then LayerNorm -> MLP -> residual add."""

    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        d_model = config.hidden_size
        # `n_inner` overrides the MLP width; when unset it is 4 * hidden size.
        mlp_width = config.n_inner
        if mlp_width is None:
            mlp_width = 4 * d_model
        eps = config.layer_norm_epsilon

        self.ln_1 = nn.LayerNorm(d_model, eps=eps)
        self.attn = GPT2Attention(config, linear_method)
        self.ln_2 = nn.LayerNorm(d_model, eps=eps)
        self.mlp = GPT2MLP(mlp_width, config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual add.

        ``kv_cache`` and ``attn_metadata`` are forwarded unchanged to the
        attention sub-layer.
        """
        # Attention sub-layer; the un-normalized input is the residual.
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = attn_out + hidden_states

        # MLP sub-layer; again the un-normalized input is the residual.
        mlp_out = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_out


class GPT2Model(nn.Module):
    """GPT-2 decoder stack without the LM head.

    Sums token embeddings (``wte``) and learned position embeddings
    (``wpe``), runs ``config.num_hidden_layers`` ``GPT2Block`` layers, and
    applies a final LayerNorm (``ln_f``).
    """

    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        # HF GPT-2 options this implementation does not support.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPT2Block(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute final hidden states for ``input_ids``.

        Args:
            input_ids: token ids looked up in ``wte``.
            position_ids: position ids looked up in ``wpe``.
            kv_caches: per-layer caches; ``kv_caches[i]`` is handed to
                layer ``i``.
            attn_metadata: batch metadata shared by every layer.

        Returns:
            Hidden states after the last block and ``ln_f``.
        """
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        for i, layer in enumerate(self.h):
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPT2LMHeadModel(nn.Module):
    """GPT-2 with a language-modeling head tied to the token embedding.

    ``forward`` returns hidden states, ``compute_logits`` projects them onto
    the vocabulary, and ``sample`` draws next tokens from the logits.
    """

    # HF's GPT-2 implements these projections with Conv1D rather than
    # Linear, so their checkpoint weights must be transposed on load.
    _CONV1D_WEIGHT_NAMES = ("c_attn", "c_proj", "c_fc")

    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = GPT2Model(config, linear_method)
        # The LM head reuses (ties) the token-embedding matrix; there is no
        # separate lm_head parameter.
        self.lm_head_weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the decoder stack and return its final hidden states."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits with the tied weights."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Draw next tokens from ``logits`` using the model's ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy HF GPT-2 checkpoint tensors into this model's parameters.

        Args:
            weights: ``(name, tensor)`` pairs. Names may or may not carry
                the ``transformer.`` prefix; it is added when missing.

        Raises:
            KeyError: if a (non-skipped) checkpoint name has no matching
                parameter.
        """
        # remove_duplicate=False keeps every name of a shared parameter
        # addressable.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights (exactly
            # once, even if a name happened to contain several of the
            # Conv1D module names).
            # Note(zhuohan): the logic below might break quantized models.
            if name.endswith(".weight") and any(
                    conv1d_name in name
                    for conv1d_name in self._CONV1D_WEIGHT_NAMES):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)