# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GPT2Config

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPT2Attention(nn.Module):
    """GPT-2 self-attention: fused QKV projection, attention, output projection.

    The total number of attention heads is divided evenly by the
    tensor-model-parallel world size; ``self.num_heads`` is the resulting
    per-rank head count that is handed to the ``Attention`` layer.

    Args:
        config: GPT-2 configuration; ``hidden_size`` and
            ``num_attention_heads`` are read from it.
        linear_method: Optional linear method forwarded unchanged to the
            ``c_attn`` and ``c_proj`` projection layers.
    """

    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # The heads must divide evenly by the tensor-parallel world size.
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        # Attention scaling factor: 1 / sqrt(head_dim).
        self.scale = self.head_dim**-0.5

        # Fused query/key/value projection (named ``c_attn`` as in the HF
        # checkpoint so weight names line up in ``load_weights``).
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        # Output projection back to ``hidden_size``.
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )
        self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Input activations for this layer.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Metadata forwarded unchanged to the attention
                layer.

        Returns:
            The attention output after the ``c_proj`` projection.
        """
        # The projection layers return a tuple; only the first element (the
        # output tensor) is used here.
        qkv, _ = self.c_attn(hidden_states)
        # Split the fused projection into three equal chunks along the last
        # dimension: query, key and value.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPT2MLP(nn.Module):
    """GPT-2 feed-forward block: ``c_fc`` -> activation -> ``c_proj``.

    Args:
        intermediate_size: Width of the hidden feed-forward layer.
        config: GPT-2 configuration; ``hidden_size`` and
            ``activation_function`` are read from it.
        linear_method: Optional linear method forwarded unchanged to both
            projection layers. Its ``quant_config`` attribute, if present,
            is passed to ``get_act_fn``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # Up-projection: hidden_size -> intermediate_size.
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            linear_method=linear_method,
        )
        # Down-projection: intermediate_size -> hidden_size.
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            linear_method=linear_method,
        )
        # ``linear_method`` may be None or lack a ``quant_config``; in both
        # cases None is passed on to ``get_act_fn``.
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the feed-forward block on ``hidden_states`` and return the result."""
        # The projection layers return a tuple; only the first element (the
        # output tensor) is used here.
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPT2Block(nn.Module):
    """One GPT-2 transformer layer (pre-LayerNorm attention + MLP).

    Each sub-layer is applied as ``x + sublayer(layer_norm(x))``.

    Args:
        config: GPT-2 configuration; ``hidden_size``, ``n_inner`` and
            ``layer_norm_epsilon`` are read from it.
        linear_method: Optional linear method forwarded unchanged to the
            attention and MLP sub-modules.
    """

    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # Fall back to 4 * hidden_size when the config does not set n_inner.
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config, linear_method)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim, config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers with residual connections.

        Args:
            hidden_states: Input activations for this layer.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer,
                forwarded to the attention sub-module.
            input_metadata: Metadata forwarded to the attention sub-module.

        Returns:
            The layer output, same variable shape-wise as produced by the
            residual additions below.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class GPT2Model(nn.Module):
    """GPT-2 backbone: token + position embeddings, layer stack, final norm.

    Args:
        config: GPT-2 configuration. Configurations with
            ``add_cross_attention``, ``scale_attn_by_inverse_layer_idx`` or
            ``reorder_and_upcast_attn`` enabled are rejected by assertion.
        linear_method: Optional linear method forwarded unchanged to every
            ``GPT2Block``.
    """

    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        # These config options are not handled by this implementation.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        # Token embeddings.
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        # Learned position embeddings.
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPT2Block(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Compute the final hidden states for ``input_ids``.

        Args:
            input_ids: Token ids looked up in ``wte``.
            position_ids: Position ids looked up in ``wpe``.
            kv_caches: One ``(key_cache, value_cache)`` pair per layer,
                indexed by layer position.
            input_metadata: Metadata forwarded unchanged to every layer.

        Returns:
            Hidden states after the last layer and the final LayerNorm.

        Raises:
            IndexError: If ``kv_caches`` has fewer entries than there are
                layers.
        """
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Index kv_caches explicitly (rather than zip) so that a too-short
        # list fails loudly instead of silently skipping layers.
        for i, layer in enumerate(self.h):
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPT2LMHeadModel(nn.Module):
    """GPT-2 language model: ``GPT2Model`` backbone plus a sampler.

    The LM head has no parameters of its own: ``lm_head_weight`` is the very
    same tensor as the token embedding weight (``transformer.wte.weight``).

    Args:
        config: GPT-2 configuration.
        linear_method: Optional linear method forwarded unchanged to the
            backbone.
    """

    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = GPT2Model(config, linear_method)
        # Tied weights: reuse the token embedding matrix as the LM head.
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run the backbone and return its final hidden states.

        No logits are computed here; sampling from the hidden states is done
        separately by ``sample``.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pass the tied LM-head weight and ``hidden_states`` to the sampler.

        Returns:
            Whatever the sampler returns for the given ``sampling_metadata``.
        """
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint weights into this model's parameters.

        Args:
            model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            load_format: Forwarded to ``hf_model_weights_iterator``.
            revision: Forwarded to ``hf_model_weights_iterator``.

        Raises:
            KeyError: If a checkpoint tensor that is not skipped has no
                matching parameter name in this model.
        """
        # remove_duplicate=False keeps every name of a shared parameter.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        conv1d_weight_names = ("c_attn", "c_proj", "c_fc")
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights.
            # Note(zhuohan): the logic below might break quantized models.
            # A single any() check guarantees the weight is transposed at
            # most once, even if a name matched more than one marker.
            if name.endswith(".weight") and any(
                    key in name for key in conv1d_weight_names):
                loaded_weight = loaded_weight.t()
            # Parameters may carry their own loader; otherwise use the
            # default one.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)