gpt2.py 12.8 KB
Newer Older
1
2
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
18
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
19
from typing import Iterable, List, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
20
21
22
23
24

import torch
from torch import nn
from transformers import GPT2Config

25
from vllm.attention import Attention, AttentionMetadata
26
from vllm.compilation.decorators import support_torch_compile
27
from vllm.config import CacheConfig, VllmConfig
28
29
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
30
from vllm.model_executor.layers.activation import get_act_fn
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
36
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
38
    ParallelLMHead, VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
42

43
44
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
45
46
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
47

Woosuk Kwon's avatar
Woosuk Kwon committed
48
49
50

class GPT2Attention(nn.Module):
    """Multi-head self-attention for a single GPT-2 layer.

    The fused QKV projection (``c_attn``) and the output projection
    (``c_proj``) are tensor-parallel linear layers.  Each tensor-parallel
    rank runs ``num_attention_heads // tp_world_size`` heads through the
    vLLM ``Attention`` backend.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        heads_total = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()

        # The head count must split evenly across tensor-parallel ranks.
        assert heads_total % tp_size == 0
        self.num_heads = heads_total // tp_size
        self.head_dim = self.hidden_size // heads_total
        # Standard 1/sqrt(head_dim) softmax scaling.
        self.scale = self.head_dim**-0.5

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            heads_total,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, run attention, then apply the output projection.

        Args:
            hidden_states: input activations for this layer.
            kv_cache: this layer's KV-cache tensor, forwarded to the backend.
            attn_metadata: batch metadata forwarded to the backend.

        Returns:
            The output of ``c_proj`` (its bias-return value is discarded).
        """
        fused_qkv, _ = self.c_attn(hidden_states)
        # The fused projection is split into three equal chunks along the
        # last dimension, in query/key/value order.
        query, key, value = fused_qkv.chunk(3, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.c_proj(context)
        return projected


class GPT2MLP(nn.Module):
    """Feed-forward sub-block of a GPT-2 layer: ``c_proj(act(c_fc(x)))``.

    ``c_fc`` is a column-parallel linear layer and ``c_proj`` is a
    row-parallel one.  Both carry a bias.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        d_model = config.hidden_size

        self.c_fc = ColumnParallelLinear(
            d_model,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        # The activation is looked up by the name stored in the HF config.
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``c_fc``, the configured activation, then ``c_proj``."""
        widened, _ = self.c_fc(hidden_states)
        activated = self.act(widened)
        narrowed, _ = self.c_proj(activated)
        return narrowed


class GPT2Block(nn.Module):
    """One GPT-2 decoder layer with pre-LayerNorm residual sub-blocks.

    ``forward`` computes::

        h   = attn(ln_1(x)) + x
        out = h + mlp(ln_2(h))
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        eps = config.layer_norm_epsilon

        # ``n_inner`` may be unset in the HF config; default to 4x hidden.
        ffn_dim = config.n_inner
        if ffn_dim is None:
            ffn_dim = 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
        self.attn = GPT2Attention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = GPT2MLP(ffn_dim,
                           config,
                           quant_config,
                           prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention then the MLP, each wrapped in a residual add."""
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # First residual connection (around attention).
        hidden_states = attn_out + hidden_states

        mlp_out = self.mlp(self.ln_2(hidden_states))
        # Second residual connection (around the MLP).
        return hidden_states + mlp_out


186
@support_torch_compile
class GPT2Model(nn.Module):
    """GPT-2 transformer trunk.

    It consists of token embeddings (``wte``), learned position embeddings
    (``wpe``), a stack of :class:`GPT2Block` layers (``h``) and a final
    LayerNorm (``ln_f``).  Under pipeline parallelism, ``forward`` runs only
    the layers with indices in ``[start_layer, end_layer)``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        # GPT-2 config options this implementation does not support.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn

        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        # ``make_layers`` invokes this factory with a ``prefix`` keyword.
        def build_block(prefix: str) -> GPT2Block:
            return GPT2Block(config, cache_config, quant_config, prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers, build_block, prefix=f"{prefix}.h")
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's share of the trunk.

        The first stage builds its input from token embeddings (or the
        caller's ``inputs_embeds``) plus position embeddings.  Later stages
        resume from ``intermediate_tensors["hidden_states"]``.  Non-final
        stages return ``IntermediateTensors``.  The final stage applies
        ``ln_f`` and returns the hidden states.

        ``kv_caches`` is indexed relative to ``start_layer``.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            hidden_states = inputs_embeds + self.wpe(position_ids)

        first_local = self.start_layer
        for layer_idx in range(first_local, self.end_layer):
            hidden_states = self.h[layer_idx](hidden_states,
                                              kv_caches[layer_idx -
                                                        first_local],
                                              attn_metadata)

        if get_pp_group().is_last_rank:
            return self.ln_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


247
class GPT2LMHeadModel(nn.Module, SupportsPP):
    """GPT-2 language model: a :class:`GPT2Model` trunk plus an LM head.

    When ``config.tie_word_embeddings`` is set, the trunk's token-embedding
    module (``transformer.wte``) doubles as the output projection.
    Otherwise a separate ``ParallelLMHead`` is created and loaded from the
    checkpoint's ``lm_head.weight``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPT2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        if self.config.tie_word_embeddings:
            # Reuse the embedding module so both share one weight tensor.
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the trunk."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the trunk and return its output.

        The output is hidden states on the last pipeline stage and
        ``IntermediateTensors`` on earlier stages.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HuggingFace GPT-2 checkpoint tensors into this model.

        Each checkpoint entry is handled as follows:

        * Attention-mask buffers (``.attn.bias`` / ``.attn.masked_bias``)
          are skipped.
        * ``lm_head.weight`` is skipped when embeddings are tied, because
          ``lm_head`` *is* ``transformer.wte`` and is filled from the
          embedding weight.  When untied, it is loaded into the separate
          ``lm_head`` under its own (unprefixed) name.
        * All other names are prefixed with ``transformer.`` if needed.
        * Parameters owned by other pipeline stages are skipped.
        * Conv1D weights (``c_attn``, ``c_proj``, ``c_fc``) are transposed,
          since HF stores them as (in, out) rather than Linear's (out, in).

        Args:
            weights: iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        tied = self.config.tie_word_embeddings
        for name, loaded_weight in weights:
            is_lm_head = "lm_head.weight" in name
            if is_lm_head and tied:
                # GPT-2 ties the weights of the embedding layer and the
                # final linear layer; the embedding weight covers both.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            # An untied lm_head lives at the top level, not under the trunk,
            # so its name must not receive the "transformer." prefix.
            if not is_lm_head and not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights (exactly
            # once, even if a name happened to match several patterns).
            # Note(zhuohan): the logic below might break quantized models.
            if name.endswith(".weight") and any(
                    conv1d_name in name
                    for conv1d_name in ("c_attn", "c_proj", "c_fc")):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params