# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import GPT2Config

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput

from .utils import is_pp_missing_parameter, make_layers


class GPT2Attention(nn.Module):
    """Self-attention sub-layer of a GPT-2 block.

    The input is projected to fused q/k/v by a tensor-parallel
    ``QKVParallelLinear`` (``c_attn``). vLLM's ``Attention`` layer then runs
    over this rank's share of the heads. A ``RowParallelLinear``
    (``c_proj``) maps the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        num_heads_all = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # Heads must divide evenly across tensor-parallel ranks.
        assert num_heads_all % tp_size == 0
        # Number of heads handled by this rank.
        self.num_heads = num_heads_all // tp_size
        self.head_dim = self.hidden_size // num_heads_all
        self.scale = self.head_dim**-0.5  # 1 / sqrt(head_dim)

        # The attribute names c_attn / c_proj must match the parameter names
        # that load_weights looks up, so they must not be renamed.
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            num_heads_all,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Input activations for this layer.
            kv_cache: This layer's KV cache, passed through to ``self.attn``.
            attn_metadata: Batch attention metadata, passed to ``self.attn``.

        Returns:
            The attention output after the ``c_proj`` projection.
        """
        fused, _ = self.c_attn(hidden_states)
        # Split the fused projection into three equal chunks along the last
        # dimension: query, key, value (in that order).
        query, key, value = fused.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.c_proj(context)
        return projected


class GPT2MLP(nn.Module):
    """Feed-forward sub-layer of a GPT-2 block.

    Computes ``c_proj(act(c_fc(x)))``. ``c_fc`` is a column-parallel linear
    layer from ``hidden_size`` to ``intermediate_size``. ``c_proj`` is a
    row-parallel linear layer going back to ``hidden_size``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        model_width = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            model_width,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            model_width,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        # Activation selected by name from the model config.
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the up-projection, the activation and the down-projection."""
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.c_proj(activated)
        return contracted


class GPT2Block(nn.Module):
    """A single GPT-2 transformer layer with pre-LayerNorm residual branches.

    Computes ``h = x + attn(ln_1(x))`` followed by ``out = h + mlp(ln_2(h))``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # When ``n_inner`` is None the feed-forward width defaults to
        # 4 * hidden_size.
        if config.n_inner is None:
            inner_dim = 4 * hidden_size
        else:
            inner_dim = config.n_inner
        eps = config.layer_norm_epsilon

        # Submodules are registered in the order ln_1, attn, ln_2, mlp.
        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
        self.attn = GPT2Attention(config,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = GPT2MLP(inner_dim,
                           config,
                           quant_config=quant_config,
                           prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers.

        Args:
            hidden_states: Input activations for this layer.
            kv_cache: This layer's KV cache, forwarded to the attention.
            attn_metadata: Batch attention metadata, forwarded to attention.

        Returns:
            The layer output, which has the same shape as ``hidden_states``.
        """
        # Attention branch plus residual connection.
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = attn_out + hidden_states

        # Feed-forward branch plus residual connection.
        mlp_out = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_out


class GPT2Model(nn.Module):
    """GPT-2 decoder stack: embeddings, transformer layers and final LayerNorm.

    This module is pipeline-parallel aware:
    - Only the first rank computes token and position embeddings; other ranks
      read ``hidden_states`` from ``intermediate_tensors``.
    - Only layers ``[start_layer, end_layer)`` run on this rank.
    - Non-last ranks return ``IntermediateTensors`` instead of applying
      ``ln_f``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        # This implementation does not support these GPT-2 config options.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        # NOTE(review): the factory's parameter is named ``prefix``.
        # make_layers presumably passes it by keyword, so keep this name.
        def _build_block(prefix: str) -> GPT2Block:
            return GPT2Block(config, cache_config, quant_config, prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            _build_block,
            prefix=f"{prefix}.h")
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids. Used only on the first pipeline rank.
            position_ids: Position ids. Used only on the first pipeline rank.
            kv_caches: KV caches for the local layers only. Entry ``j``
                belongs to layer ``start_layer + j``.
            attn_metadata: Batch attention metadata.
            intermediate_tensors: Activations from the previous pipeline
                rank. Required on every rank except the first.

        Returns:
            Final hidden states after ``ln_f`` on the last rank. On all other
            ranks, ``IntermediateTensors`` holding ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            hidden_states = self.wte(input_ids) + self.wpe(position_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for local_idx, layer_idx in enumerate(
                range(self.start_layer, self.end_layer)):
            hidden_states = self.h[layer_idx](hidden_states,
                                              kv_caches[local_idx],
                                              attn_metadata)

        if get_pp_group().is_last_rank:
            return self.ln_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


class GPT2LMHeadModel(nn.Module):
    """GPT-2 decoder with a language-modeling head tied to the token embedding.

    ``forward`` runs the decoder stack. ``compute_logits`` projects hidden
    states onto the vocabulary through the tied ``lm_head``. ``sample`` picks
    next tokens. ``load_weights`` loads HuggingFace GPT-2 checkpoints.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPT2Model(config,
                                     cache_config,
                                     quant_config,
                                     prefix="transformer")
        # GPT-2 ties the output projection to the token embedding, so the
        # LM head is the very same module as ``transformer.wte``.
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack.

        Returns:
            Whatever ``self.transformer`` returns: final hidden states on the
            last pipeline rank, and ``IntermediateTensors`` on earlier ranks.
            The previous ``-> torch.Tensor`` annotation was wrong for the
            non-last ranks.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` to vocabulary logits via ``lm_head``.

        Returns whatever ``self.logits_processor`` produces (annotated as
        Optional).
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Select next tokens from ``logits`` using ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Build a zero-filled ``IntermediateTensors`` placeholder.

        The single ``"hidden_states"`` entry has shape
        ``(batch_size, config.hidden_size)`` and the given dtype and device.
        """
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HuggingFace GPT-2 checkpoint tensors into this model.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint. Names may
                omit the ``"transformer."`` prefix; it is added when missing.

        Behavior:
            - ``lm_head.weight`` is skipped because it is tied to ``wte``.
            - Attention-mask buffers are skipped.
            - Parameters that live on other pipeline ranks are skipped.
            - Conv1D weights are transposed exactly once.
        """
        # The HF GPT-2 implementation uses Conv1D instead of Linear for these
        # projections. Their weights must be transposed before loading.
        conv1d_weight_names = ("c_attn", "c_proj", "c_fc")
        # remove_duplicate=False lists a parameter shared by several modules
        # (here the tied lm_head / transformer.wte) under every name.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # Note(zhuohan): the logic below might break quantized models.
            # A single check guarantees at most one transpose. The old
            # per-name loop would transpose again for every matching
            # substring.
            if name.endswith(".weight") and any(
                    conv_name in name for conv_name in conv1d_weight_names):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)