# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPT2Config

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput


class GPT2Attention(nn.Module):
    """Self-attention sub-layer of a GPT-2 block, sharded across TP ranks.

    ``c_attn`` is a single fused projection whose output is split into Q, K
    and V in ``forward``. The attention heads are divided evenly among the
    tensor-parallel ranks, and ``c_proj`` maps the attention result back to
    ``hidden_size``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        n_heads_total = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # Each tensor-parallel rank must own the same number of heads.
        assert n_heads_total % tp_size == 0
        self.num_heads = n_heads_total // tp_size
        self.head_dim = self.hidden_size // n_heads_total
        # Softmax scaling factor 1 / sqrt(head_dim).
        self.scale = self.head_dim**-0.5

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            n_heads_total,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            cache_config=cache_config,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and apply the output projection.

        Args:
            hidden_states: input activations fed to the fused QKV projection.
            kv_cache: this layer's KV cache, passed straight to ``self.attn``.
            attn_metadata: attention metadata, passed straight to ``self.attn``.

        Returns:
            The output of ``c_proj`` applied to the attention result.
        """
        fused_qkv, _ = self.c_attn(hidden_states)
        # The fused projection's output is split into three equal chunks
        # along the last dimension: Q, K, V.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.c_proj(context)
        return projected


class GPT2MLP(nn.Module):
    """Feed-forward sub-layer of a GPT-2 block: c_fc -> activation -> c_proj."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        d_model = config.hidden_size
        # Column-parallel up-projection to ``intermediate_size`` ...
        self.c_fc = ColumnParallelLinear(
            d_model,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        # ... followed by a row-parallel down-projection back to the model
        # width.
        self.c_proj = RowParallelLinear(
            intermediate_size,
            d_model,
            bias=True,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation and down-projection in turn."""
        up, _ = self.c_fc(hidden_states)
        activated = self.act(up)
        down, _ = self.c_proj(activated)
        return down


class GPT2Block(nn.Module):
    """One GPT-2 decoder layer.

    Both sub-layers (attention, then MLP) normalize their input with a
    LayerNorm and add the un-normalized input back as a residual.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.hidden_size
        # When ``n_inner`` is not set, the MLP width falls back to 4x the
        # hidden size.
        if config.n_inner is None:
            mlp_width = 4 * width
        else:
            mlp_width = config.n_inner

        self.ln_1 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(mlp_width, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers, each with a residual add.

        Args:
            hidden_states: input activations of this layer.
            kv_cache: this layer's KV cache, forwarded to the attention.
            attn_metadata: attention metadata, forwarded to the attention.

        Returns:
            The layer's output activations.
        """
        # Attention sub-layer: normalize, attend, then add the input back.
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = attn_out + hidden_states

        # MLP sub-layer with the same normalize / transform / add pattern.
        mlp_out = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_out


class GPT2Model(nn.Module):
    """GPT-2 decoder stack without the LM head.

    Sums token embeddings (``wte``) and learned position embeddings (``wpe``),
    runs them through ``config.num_hidden_layers`` ``GPT2Block`` layers, and
    applies a final LayerNorm (``ln_f``).
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        # These HF GPT-2 options are not implemented here; they are asserted
        # to be disabled.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPT2Block(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed the inputs and run them through every decoder layer.

        Args:
            input_ids: token ids looked up in ``wte``.
            position_ids: positions looked up in ``wpe``.
            kv_caches: per-layer KV caches; ``kv_caches[i]`` is given to
                ``self.h[i]``.
            attn_metadata: attention metadata shared by all layers.

        Returns:
            Hidden states after the final LayerNorm.

        Raises:
            IndexError: if ``kv_caches`` has fewer entries than there are
                layers.
        """
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)

        # Index kv_caches instead of zipping so that a short cache list
        # still fails loudly rather than silently skipping layers.
        for layer_idx, layer in enumerate(self.h):
            hidden_states = layer(hidden_states, kv_caches[layer_idx],
                                  attn_metadata)

        return self.ln_f(hidden_states)


class GPT2LMHeadModel(nn.Module):
    """GPT-2 with a language-modeling head tied to the token embedding."""

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPT2Model(config, cache_config, quant_config)
        # GPT-2 ties the LM head to the input embedding: logits are computed
        # against the token-embedding matrix itself, so no separate head
        # parameter exists.
        self.lm_head_weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the final hidden states of the decoder stack.

        Logits are not produced here; call ``compute_logits`` on the result.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states onto the (tied) embedding matrix."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` using the model's ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> None:
        """Copy HuggingFace GPT-2 checkpoint tensors into this model.

        For each ``(name, tensor)`` pair:

        * ``lm_head.weight`` is skipped (the head is tied to ``wte``);
        * attention-mask entries (``.attn.bias`` / ``.attn.masked_bias``)
          are skipped;
        * names without the ``transformer.`` prefix get it prepended;
        * ``.weight`` tensors of ``c_attn`` / ``c_proj`` / ``c_fc`` are
          transposed exactly once;
        * the tensor is handed to the parameter's ``weight_loader`` if it has
          one, else to ``default_weight_loader``.

        Raises:
            KeyError: if a normalized checkpoint name matches no parameter.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        conv1d_layer_names = ("c_attn", "c_proj", "c_fc")
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights. A single
            # any() test guarantees at most one transpose per tensor, even if
            # a name happened to contain several of the layer names.
            # Note(zhuohan): the logic below might break quantized models.
            if name.endswith(".weight") and any(
                    layer_name in name for layer_name in conv1d_layer_names):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)