gpt2.py 12.2 KB
Newer Older
1
2
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
Woosuk Kwon's avatar
Woosuk Kwon committed
3
# Copyright 2023 The vLLM team.
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
18
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
19
from typing import Iterable, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
20
21
22
23
24

import torch
from torch import nn
from transformers import GPT2Config

25
from vllm.attention import Attention, AttentionMetadata
26
from vllm.compilation.decorators import support_torch_compile
27
from vllm.config import CacheConfig, VllmConfig
28
29
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
30
from vllm.model_executor.layers.activation import get_act_fn
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
36
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
38
    ParallelLMHead, VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
42

43
44
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
45
46
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
47

Woosuk Kwon's avatar
Woosuk Kwon committed
48
49
50

class GPT2Attention(nn.Module):
    """Multi-head causal self-attention used inside one GPT-2 layer.

    The attention heads are divided evenly over the tensor-parallel ranks.
    A single fused ``c_attn`` projection yields this rank's Q, K and V, the
    vLLM ``Attention`` op consumes them together with the KV cache, and the
    row-parallel ``c_proj`` maps the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.hidden_size = hidden_size
        heads_total = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()

        # Every rank must receive the same whole number of heads.
        assert heads_total % tp_size == 0
        self.num_heads = heads_total // tp_size
        self.head_dim = hidden_size // heads_total
        self.scale = self.head_dim**-0.5

        # Fused Q/K/V projection (no separate KV head count is passed).
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            heads_total,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            hidden_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output."""
        fused, _ = self.c_attn(hidden_states)
        # The fused output holds Q, K and V back to back on the last dim.
        query, key, value = fused.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        output, _ = self.c_proj(context)
        return output


class GPT2MLP(nn.Module):
    """GPT-2 feed-forward sub-layer: ``c_fc`` -> activation -> ``c_proj``.

    ``c_fc`` is column-parallel and ``c_proj`` is row-parallel, so the
    intermediate activation stays sharded across tensor-parallel ranks.
    The activation function is chosen by ``config.activation_function``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        embed_dim = config.hidden_size
        fc_prefix = f"{prefix}.c_fc"
        proj_prefix = f"{prefix}.c_proj"
        self.c_fc = ColumnParallelLinear(embed_dim,
                                         intermediate_size,
                                         bias=True,
                                         quant_config=quant_config,
                                         prefix=fc_prefix)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        embed_dim,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=proj_prefix)
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply expand -> activation -> contract to ``hidden_states``."""
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.c_proj(activated)
        return contracted


class GPT2Block(nn.Module):
    """One GPT-2 transformer layer in pre-LayerNorm residual form.

    Computes ``x = attn(ln_1(x)) + x`` followed by ``x = x + mlp(ln_2(x))``.
    The MLP width is ``config.n_inner``, or ``4 * hidden_size`` when that
    field is ``None``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        eps = config.layer_norm_epsilon
        if config.n_inner is None:
            inner_dim = 4 * hidden_size
        else:
            inner_dim = config.n_inner

        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
        self.attn = GPT2Attention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = GPT2MLP(inner_dim,
                           config,
                           quant_config,
                           prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers, each with a skip connection."""
        # Attention sub-layer on the normalized input, then add the skip path.
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = attn_out + hidden_states

        # Feed-forward sub-layer on the normalized input, then add the skip path.
        mlp_out = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_out


185
@support_torch_compile
class GPT2Model(nn.Module):
    """GPT-2 transformer trunk (embeddings, layer stack, final LayerNorm).

    Under pipeline parallelism each rank builds only its own slice of
    layers, ``[start_layer, end_layer)``, via ``make_layers``.
    - The first rank embeds tokens and positions.
    - Later ranks continue from the ``hidden_states`` held in
      ``intermediate_tensors``.
    - Every rank except the last hands its output on as
      ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config

        # GPT-2 variants this implementation does not support.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn

        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        # The callback's parameter must be named ``prefix``: make_layers
        # passes the per-layer prefix by keyword.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPT2Block(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline rank's share of the GPT-2 trunk.

        ``kv_caches`` is indexed relative to ``start_layer``: entry 0 belongs
        to the first layer this rank owns.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            hidden_states = self.wte(input_ids) + self.wpe(position_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        layer_ids = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(layer_ids):
            hidden_states = self.h[layer_idx](hidden_states,
                                              kv_caches[cache_idx],
                                              attn_metadata)

        if not pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return self.ln_f(hidden_states)


241
class GPT2LMHeadModel(nn.Module, SupportsPP):
    """GPT-2 causal language model for vLLM inference.

    Wraps :class:`GPT2Model` as ``self.transformer`` and adds the LM head,
    logits processor and sampler.
    - If ``config.tie_word_embeddings`` is set, the LM head *is* the token
      embedding module ``transformer.wte``.
    - Otherwise a separate :class:`ParallelLMHead` is created and filled
      from the checkpoint's ``lm_head.weight`` in :meth:`load_weights`.
    """

    # HF GPT-2 stores these projections as Conv1D, whose weight layout is
    # the transpose of nn.Linear's.
    _CONV1D_WEIGHT_NAMES = ("c_attn", "c_proj", "c_fc")

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPT2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        if self.config.tie_word_embeddings:
            # Share the input embedding as the output projection.
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size,
                                          quant_config=quant_config,
                                          prefix=maybe_prefix(
                                              prefix, "lm_head"))
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the transformer trunk.

        Returns hidden states on the last pipeline rank, and
        ``IntermediateTensors`` on every other rank.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Draw next tokens from ``logits`` using the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy HuggingFace GPT-2 checkpoint tensors into this model.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint. Names may
                or may not carry the ``transformer.`` prefix.

        Name handling:
        - Attention-mask buffers (``.attn.bias`` / ``.attn.masked_bias``)
          are skipped.
        - ``lm_head.weight`` is skipped only when the head is tied to
          ``transformer.wte``. Otherwise it is loaded into the standalone
          ``ParallelLMHead``.
        - Parameters not owned by this pipeline rank are skipped.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        tied = self.config.tie_word_embeddings
        for name, loaded_weight in weights:
            if tied and "lm_head.weight" in name:
                # The head shares transformer.wte, which the checkpoint's
                # embedding entry already fills.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            # Headless checkpoints omit the "transformer." prefix; the
            # untied LM head lives at the top level, not under transformer.
            if not name.startswith(("transformer.", "lm_head.")):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights (once).
            # Note(zhuohan): the logic below might break quantized models.
            if name.endswith(".weight") and any(
                    conv_name in name
                    for conv_name in self._CONV1D_WEIGHT_NAMES):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)