gpt2.py 13.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
Woosuk Kwon's avatar
Woosuk Kwon committed
6
# Copyright 2023 The vLLM team.
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
21
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
22

23
from collections.abc import Iterable
24
from itertools import islice
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28
29

import torch
from torch import nn
from transformers import GPT2Config

30
from vllm.attention.layer import Attention
31
from vllm.compilation.decorators import support_torch_compile
32
from vllm.config import CacheConfig, VllmConfig
33
from vllm.distributed.parallel_state import (
34
35
36
    get_pp_group,
    get_tensor_model_parallel_world_size,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
37
from vllm.model_executor.layers.activation import get_act_fn
38
39
40
41
42
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
43
from vllm.model_executor.layers.logits_processor import LogitsProcessor
44
from vllm.model_executor.layers.pooler import DispatchPooler
45
from vllm.model_executor.layers.quantization import QuantizationConfig
46
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
48
49
    ParallelLMHead,
    VocabParallelEmbedding,
)
50
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
51
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
52

53
from .interfaces import SupportsCrossEncoding, SupportsPP
54
55
56
57
58
59
60
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
61

Woosuk Kwon's avatar
Woosuk Kwon committed
62
63

class GPT2Attention(nn.Module):
    """GPT-2 self-attention: fused QKV projection, attention, output projection.

    The attention heads are split evenly across the tensor-parallel group, so
    ``num_heads`` is the per-rank head count while ``head_dim`` is derived
    from the full (unsplit) head count.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the projection layers and the attention backend.

        Args:
            config: HF GPT-2 config; ``hidden_size`` and
                ``num_attention_heads`` are read from it.
            cache_config: Forwarded unchanged to ``Attention``.
            quant_config: Forwarded unchanged to every sub-layer.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        heads_total = config.num_attention_heads
        tp_world_size = get_tensor_model_parallel_world_size()
        # Each tensor-parallel rank must own a whole number of heads.
        assert heads_total % tp_world_size == 0
        self.num_heads = heads_total // tp_world_size
        self.head_dim = self.hidden_size // heads_total
        self.scale = self.head_dim**-0.5

        # NOTE: keep the assignment order below; it fixes the order in which
        # the sub-modules are registered on this module.
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            heads_total,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, run attention, and apply the output projection."""
        fused_qkv, _ = self.c_attn(hidden_states)
        # The fused projection output is split into three equal parts along
        # the last dimension: query, key, value.
        query, key, value = fused_qkv.chunk(3, dim=-1)
        projected, _ = self.c_proj(self.attn(query, key, value))
        return projected


class GPT2MLP(nn.Module):
    """GPT-2 feed-forward block: ``c_fc`` -> activation -> ``c_proj``."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the two projections and look up the activation.

        Args:
            intermediate_size: Output width of ``c_fc`` / input width of
                ``c_proj``.
            config: HF GPT-2 config; ``hidden_size`` and
                ``activation_function`` are read from it.
            quant_config: Forwarded unchanged to both linear layers.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        model_width = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            model_width,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            model_width,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation, and down-projection in order."""
        # Both linear layers return a tuple; only the first element is used.
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.c_proj(activated)
        return contracted


class GPT2Block(nn.Module):
    """One GPT-2 layer: pre-LayerNorm attention and MLP, each with a residual."""

    def __init__(
        self,
        config: GPT2Config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the two layer norms, the attention module, and the MLP.

        Args:
            config: HF GPT-2 config; ``hidden_size``, ``n_inner`` and
                ``layer_norm_epsilon`` are read from it.
            cache_config: Forwarded unchanged to ``GPT2Attention``.
            quant_config: Forwarded unchanged to attention and MLP.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()
        width = config.hidden_size
        # An unset ``n_inner`` falls back to four times the hidden size.
        if config.n_inner is not None:
            mlp_width = config.n_inner
        else:
            mlp_width = 4 * width
        norm_eps = config.layer_norm_epsilon

        self.ln_1 = nn.LayerNorm(width, eps=norm_eps)
        self.attn = GPT2Attention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.attn",
        )
        self.ln_2 = nn.LayerNorm(width, eps=norm_eps)
        self.mlp = GPT2MLP(
            mlp_width,
            config,
            quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and MLP sub-layers, adding the input back after each."""
        # Attention sub-layer: normalize, attend, add the un-normalized input.
        attn_out = self.attn(hidden_states=self.ln_1(hidden_states))
        after_attn = attn_out + hidden_states

        # MLP sub-layer: same pattern on the attention result.
        mlp_out = self.mlp(self.ln_2(after_attn))
        return after_attn + mlp_out


185
@support_torch_compile
class GPT2Model(nn.Module):
    """GPT-2 transformer stack: token + position embeddings, blocks, final norm.

    Supports pipeline parallelism: only the first pipeline rank embeds the
    inputs, only the last rank applies ``ln_f``, and intermediate ranks pass
    ``hidden_states`` along via ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build embeddings, the transformer blocks, and the final LayerNorm.

        Args:
            vllm_config: Source of the HF model config, cache config, and
                quantization config.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        # These GPT-2 config variants are not implemented here.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.wte",
        )
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPT2Block(config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.n_embd
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings (``wte``) for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the transformer.

        Args:
            input_ids: Token ids; only used on the first pipeline rank when
                ``inputs_embeds`` is ``None``.
            position_ids: Indices into the position embedding ``wpe``; only
                used on the first pipeline rank.
            intermediate_tensors: Output of the previous pipeline rank; must
                not be ``None`` on non-first ranks.
            inputs_embeds: Optional precomputed token embeddings that replace
                the ``wte`` lookup on the first rank.

        Returns:
            The ``ln_f``-normalized hidden states on the last pipeline rank,
            otherwise an ``IntermediateTensors`` holding ``hidden_states``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.embed_input_ids(input_ids)
            position_embeds = self.wpe(position_ids)
            hidden_states = inputs_embeds + position_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # Only the layers in [start_layer, end_layer) are run on this rank.
        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = layer(hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs, with names relative to this
                module.

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a name that is neither skipped nor missing due to
                pipeline parallelism does not match any parameter.
        """
        # The HF's GPT-2 implementation uses Conv1D instead of Linear.
        # Because of this, the weights of these layers must be transposed.
        conv1d_markers = ("c_attn", "c_proj", "c_fc")

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # Transpose at most once: the previous per-marker loop would have
            # transposed twice (a no-op) for a name matching two markers.
            # Note(zhuohan): the logic below might break quantized models.
            if name.endswith(".weight") and any(
                marker in name for marker in conv1d_markers
            ):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

Woosuk Kwon's avatar
Woosuk Kwon committed
271

272
class GPT2LMHeadModel(nn.Module, SupportsPP):
    """GPT-2 with a language-modeling head on top of ``GPT2Model``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the transformer, the LM head, and the logits processor.

        Args:
            vllm_config: Source of the HF model config and quantization config.
            prefix: Dotted module path of this model; may be empty.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPT2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=quant_config,
            # Use maybe_prefix (as for ``transformer``) so that an empty
            # ``prefix`` yields "lm_head" rather than ".lm_head".
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # Share the head's weights with the token embedding.
            self.lm_head = self.lm_head.tie_weights(self.transformer.wte)

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the transformer's token embeddings for ``input_ids``."""
        return self.transformer.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the transformer and return its output unchanged.

        Logits are not computed here; see ``compute_logits``.
        """
        hidden_states = self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Apply the logits processor with ``lm_head`` to ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, adding a ``transformer.`` prefix where absent.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(self)
        weights = _add_transformer_prefix(weights)
        return loader.load_weights(weights)


324
class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """GPT2 Model for sequence classification.

    This class wraps GPT2Model together with a linear scoring layer and a
    pooler that uses that layer as its classifier.

    Attributes:
        transformer: An instance of GPT2Model used for forward operations.
        score: A bias-free linear layer for calculating logits.
        pooler: The pooler built by ``DispatchPooler.for_seq_cls`` from the
            model's pooler config, with ``score`` as the classifier.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the transformer, the scoring layer, and the pooler.

        Args:
            vllm_config: Source of the HF model config, head dtype, and pooler
                config; the pooler config must not be ``None``.
            prefix: Dotted module path of this model; may be empty.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.transformer = GPT2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "gpt2")
        )
        self.score = nn.Linear(
            config.n_embd,
            config.num_labels,
            bias=False,
            dtype=vllm_config.model_config.head_dtype,
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the transformer's token embeddings for ``input_ids``."""
        return self.transformer.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the transformer and return its hidden states.

        Neither ``score`` nor ``pooler`` is applied in this method.
        """
        hidden_states = self.transformer(
            input_ids=input_ids,
            position_ids=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
        return hidden_states
377
378


379
def _add_transformer_prefix(
380
    weights: Iterable[tuple[str, torch.Tensor]],
381
382
) -> Iterable[tuple[str, torch.Tensor]]:
    for name, tensor in weights:
383
384
        if not name.startswith("transformer.") and not name.startswith("lm_head"):
            name = "transformer." + name
zhuwenwen's avatar
zhuwenwen committed
385
        yield name, tensor