gpt2.py 10 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
19
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
20
from typing import Iterable, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
21
22
23
24
25

import torch
from torch import nn
from transformers import GPT2Config

26
from vllm.attention import Attention, AttentionMetadata
27
from vllm.config import CacheConfig
28
from vllm.distributed import get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
29
from vllm.model_executor.layers.activation import get_act_fn
30
31
32
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
33
from vllm.model_executor.layers.logits_processor import LogitsProcessor
34
35
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.model_executor.layers.sampler import Sampler
37
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
42
43
44
45


class GPT2Attention(nn.Module):
    """Causal self-attention of a GPT-2 block, sharded over tensor-parallel ranks.

    A single fused ``c_attn`` projection yields Q, K and V for this rank's
    heads. The three parts are passed to vLLM's ``Attention`` layer (which
    reads/writes ``kv_cache``), and ``c_proj`` maps the result back to
    ``hidden_size``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Heads are split evenly across tensor-parallel ranks.
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        # HF GPT-2 divides attention scores by sqrt(head_dim) only when
        # ``scale_attn_weights`` is enabled (its default). Honour the flag so
        # checkpoints trained without scaling are reproduced faithfully;
        # getattr keeps configs lacking the attribute on the default path.
        if getattr(config, "scale_attn_weights", True):
            self.scale = self.head_dim**-0.5
        else:
            self.scale = 1.0

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output.

        Args:
            hidden_states: input activations for this layer.
            kv_cache: this layer's KV cache, forwarded to ``self.attn``.
            attn_metadata: batch/attention metadata, forwarded to ``self.attn``.
        """
        # The second element returned by the parallel linears is an optional
        # separate bias; it is unused here.
        qkv, _ = self.c_attn(hidden_states)
        # Split the fused projection along the last dim into equal Q, K, V.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPT2MLP(nn.Module):
    """Feed-forward sub-layer of a GPT-2 block.

    ``c_fc`` (column-parallel) expands ``hidden_size`` to
    ``intermediate_size``, the activation named by
    ``config.activation_function`` is applied, and ``c_proj``
    (row-parallel) projects back to ``hidden_size``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        d_model = config.hidden_size
        # Expansion projection: d_model -> intermediate_size.
        self.c_fc = ColumnParallelLinear(d_model,
                                         intermediate_size,
                                         bias=True,
                                         quant_config=quant_config)
        # Contraction projection: intermediate_size -> d_model.
        self.c_proj = RowParallelLinear(intermediate_size,
                                        d_model,
                                        bias=True,
                                        quant_config=quant_config)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply expand -> activation -> contract to ``hidden_states``."""
        # Parallel linears return (output, optional_bias); only output is used.
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.c_proj(activated)
        return contracted


class GPT2Block(nn.Module):
    """One pre-LayerNorm GPT-2 transformer layer.

    The layer computes ``x + attn(ln_1(x))`` and then ``h + mlp(ln_2(h))``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # ``n_inner`` overrides the MLP width; otherwise use 4 * hidden_size.
        if config.n_inner is None:
            inner_dim = 4 * hidden_size
        else:
            inner_dim = config.n_inner

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual add.

        Args:
            hidden_states: input activations for this layer.
            kv_cache: this layer's KV cache, forwarded to the attention layer.
            attn_metadata: attention metadata, forwarded to the attention layer.
        """
        # Attention sub-layer: normalise, attend, add the untouched input back.
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = attn_out + hidden_states

        # MLP sub-layer: normalise, feed forward, add the residual back.
        mlp_out = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_out


class GPT2Model(nn.Module):
    """GPT-2 decoder trunk with no output head.

    The module has token embeddings (``wte``), learned position embeddings
    (``wpe``), ``config.num_hidden_layers`` ``GPT2Block`` layers (``h``), and
    a final LayerNorm (``ln_f``).
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        # These HF GPT-2 options are not implemented by this port; the
        # config must have them disabled.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPT2Block(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed the inputs, run every layer, and return final hidden states.

        Args:
            input_ids: token ids to embed via ``wte``.
            position_ids: position indices to embed via ``wpe``.
            kv_caches: one KV cache per layer; layer ``i`` uses
                ``kv_caches[i]``.
            attn_metadata: attention metadata shared by all layers.
        """
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        for layer_idx, layer in enumerate(self.h):
            hidden_states = layer(hidden_states, kv_caches[layer_idx],
                                  attn_metadata)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPT2LMHeadModel(nn.Module):
    """GPT-2 trunk plus a tied language-model head, logits processing and sampling.

    The output projection reuses the token-embedding matrix
    (``transformer.wte.weight``), as GPT-2 ties these weights.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPT2Model(config, cache_config, quant_config)
        # Tied LM head: the same Parameter object as the token embedding.
        self.lm_head_weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the trunk's final hidden states; logits are computed separately."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project ``hidden_states`` onto the vocabulary with the tied weight."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Draw next tokens from ``logits`` using the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy HuggingFace GPT-2 checkpoint tensors into this model.

        Args:
            weights: ``(name, tensor)`` pairs from an HF checkpoint. Names may
                or may not carry the ``transformer.`` prefix.

        Raises:
            KeyError: if a (non-skipped) checkpoint name has no matching
                parameter in this model.
        """
        # remove_duplicate=False keeps aliases such as the tied
        # ``lm_head_weight`` in the mapping alongside ``transformer.wte.weight``.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        # HF's GPT-2 stores these projections as Conv1D, whose weight is the
        # transpose of a Linear weight.
        conv1d_weight_names = ("c_attn", "c_proj", "c_fc")
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights -- exactly
            # once, even if a name happened to match several substrings.
            # Note(zhuohan): the logic below might break quantized models.
            if name.endswith(".weight") and any(
                    conv_name in name for conv_name in conv1d_weight_names):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)