# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
24
25
26
27
28

import torch
from torch import nn
from transformers import GPT2Config

29
from vllm.attention import Attention
30
from vllm.compilation.decorators import support_torch_compile
31
from vllm.config import CacheConfig, VllmConfig
32
33
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.model_executor.layers.activation import get_act_fn
35
36
37
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
38
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
from vllm.model_executor.layers.quantization import QuantizationConfig
40
from vllm.model_executor.layers.vocab_parallel_embedding import (
41
    ParallelLMHead, VocabParallelEmbedding)
42
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
43
from vllm.model_executor.sampling_metadata import SamplingMetadata
44
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
45

46
from .interfaces import SupportsPP
47
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
48
49
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
50

Woosuk Kwon's avatar
Woosuk Kwon committed
51
52
53

class GPT2Attention(nn.Module):
    """GPT-2 self-attention, with heads split across tensor-parallel ranks.

    ``c_attn`` is a fused QKV projection and ``c_proj`` the output
    projection; the attention computation itself is delegated to
    :class:`vllm.attention.Attention`.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        heads_total = config.num_attention_heads
        # Each rank must own a whole number of heads.
        assert heads_total % tp_size == 0

        self.hidden_size = config.hidden_size
        self.num_heads = heads_total // tp_size
        self.head_dim = self.hidden_size // heads_total
        self.scale = self.head_dim**-0.5

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            heads_total,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        fused, _ = self.c_attn(hidden_states)
        # The fused projection output is split into three equal pieces along
        # the last dimension: query, key, value.
        query, key, value = fused.chunk(3, dim=-1)
        projected, _ = self.c_proj(self.attn(query, key, value))
        return projected


class GPT2MLP(nn.Module):
    """GPT-2 feed-forward sub-layer: ``c_fc`` -> activation -> ``c_proj``."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        model_width = config.hidden_size

        # Model width -> intermediate_size.
        self.c_fc = ColumnParallelLinear(
            model_width,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        # intermediate_size -> model width.
        self.c_proj = RowParallelLinear(
            intermediate_size,
            model_width,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded, _ = self.c_fc(hidden_states)
        projected, _ = self.c_proj(self.act(expanded))
        return projected


class GPT2Block(nn.Module):
    """One GPT-2 transformer layer.

    Pre-LayerNorm layout: ``ln_1`` feeds the attention sub-layer and ``ln_2``
    feeds the MLP sub-layer, and each sub-layer's output is added back onto
    its input (residual connection).
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        eps = config.layer_norm_epsilon
        # ``n_inner`` overrides the MLP width; otherwise use 4x the width.
        if config.n_inner is None:
            mlp_width = 4 * hidden_size
        else:
            mlp_width = config.n_inner

        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
        self.attn = GPT2Attention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = GPT2MLP(mlp_width,
                           config,
                           quant_config,
                           prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # Attention sub-layer with residual connection.
        attn_output = self.attn(hidden_states=self.ln_1(hidden_states))
        hidden_states = attn_output + hidden_states
        # Feed-forward sub-layer with residual connection.
        mlp_output = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_output


181
@support_torch_compile
class GPT2Model(nn.Module):
    """GPT-2 decoder stack: embeddings, transformer blocks, final LayerNorm.

    Pipeline parallelism is supported. Only the first rank computes the
    token + position embeddings, and each rank runs its own slice
    ``h[start_layer:end_layer]`` of the blocks. Only the last rank applies
    ``ln_f``; other ranks return an :class:`IntermediateTensors` carrying
    ``"hidden_states"``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        # Model variants this implementation does not support.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          self.embed_dim,
                                          quant_config=quant_config,
                                          prefix=maybe_prefix(prefix, "wte"))
        # Position embeddings, indexed by ``position_ids`` in forward().
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPT2Block(
                config, cache_config, quant_config, prefix=prefix),
            prefix=maybe_prefix(prefix, "h"))
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        # Use the same width attribute as the embeddings/LayerNorm above.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    self.embed_dim))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings ``wte(input_ids)``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's share of the decoder.

        Args:
            input_ids: Token ids. They are only used on the first PP rank, and
                only when ``inputs_embeds`` is not given.
            position_ids: Indices into the position embedding ``wpe``.
            intermediate_tensors: Hidden states from the previous PP rank.
                This is required on every rank except the first.
            inputs_embeds: Optional precomputed token embeddings.

        Returns:
            On the last PP rank, the ``ln_f``-normalized hidden states.
            On every other rank, an ``IntermediateTensors`` for the next one.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            position_embeds = self.wpe(position_ids)
            hidden_states = inputs_embeds + position_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in self.h[self.start_layer:self.end_layer]:
            hidden_states = layer(hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load HuggingFace GPT-2 checkpoint tensors into this module.

        Args:
            weights: ``(name, tensor)`` pairs whose names match this module's
                parameter names.

        Returns:
            The set of parameter names that were loaded.
        """
        # The HF's GPT-2 implementation uses Conv1D instead of Linear, so
        # the weight matrices of these layers must be transposed.
        # Note(zhuohan): the logic below might break quantized models.
        conv1d_layer_names = ("c_attn", "c_proj", "c_fc")

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask buffers.
                # NOTE: "c_attn.bias" should not be skipped.
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # Transpose exactly once, even if a name were to contain more
            # than one of the Conv1D layer names.
            if name.endswith(".weight") and any(
                    layer_name in name for layer_name in conv1d_layer_names):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

Woosuk Kwon's avatar
Woosuk Kwon committed
268

269
class GPT2LMHeadModel(nn.Module, SupportsPP):
    """GPT-2 decoder (:class:`GPT2Model`) with a language-modeling head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPT2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        # Build the layer name like ``transformer`` does, so a top-level model
        # gets "lm_head" rather than ".lm_head". Quantization configs match
        # layers by this name.
        self.lm_head = ParallelLMHead(self.config.vocab_size,
                                      self.config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if self.config.tie_word_embeddings:
            # Tie the output head to the token embedding ``wte``.
            self.lm_head = self.lm_head.tie_weights(self.transformer.wte)

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings from the underlying transformer."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder.

        Returns hidden states on the last PP rank, or intermediate tensors
        on other ranks. Logits are computed separately by
        :meth:`compute_logits`.
        """
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names that were loaded.

        Names not already starting with ``transformer.`` or ``lm_head`` are
        nested under ``transformer.`` before dispatch.
        """
        loader = AutoWeightsLoader(self)
        weights = _add_transformer_prefix(weights)
        return loader.load_weights(weights)


def _add_transformer_prefix(
    weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
    """Lazily re-yield ``weights`` with names nested under ``transformer.``.

    Names that already start with ``transformer.`` or with ``lm_head`` pass
    through unchanged. Every other name gets the ``transformer.`` prefix.
    Tensors are yielded as-is.
    """
    untouched_prefixes = ("transformer.", "lm_head")
    for name, tensor in weights:
        new_name = (name if name.startswith(untouched_prefixes) else
                    "transformer." + name)
        yield new_name, tensor