gpt2.py 11.1 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
19
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
20
from typing import Iterable, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
21
22
23
24
25

import torch
from torch import nn
from transformers import GPT2Config

26
from vllm.attention import Attention, AttentionMetadata
27
from vllm.config import CacheConfig
28
29
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
30
from vllm.model_executor.layers.activation import get_act_fn
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
36
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
Woosuk Kwon's avatar
Woosuk Kwon committed
37
from vllm.model_executor.layers.sampler import Sampler
38
39
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
41
from vllm.model_executor.sampling_metadata import SamplingMetadata
42
from vllm.sequence import IntermediateTensors, SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
43

44
45
from .utils import is_pp_missing_parameter, make_layers

Woosuk Kwon's avatar
Woosuk Kwon committed
46
47
48

class GPT2Attention(nn.Module):
    """Self-attention sub-layer of a GPT-2 block.

    ``c_attn`` produces the fused query/key/value projection and ``c_proj``
    maps the attention output back to ``hidden_size``. The model's heads are
    divided evenly across tensor-parallel ranks. This instance runs
    ``num_heads`` of them.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        heads_total = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()

        # Every tensor-parallel rank must receive the same number of heads.
        assert heads_total % tp_size == 0
        self.num_heads = heads_total // tp_size
        self.head_dim = self.hidden_size // heads_total
        # Scores are scaled by 1 / sqrt(head_dim).
        self.scale = self.head_dim**-0.5

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            heads_total,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states`` and project the result.

        The fused QKV output is split into three equal chunks along the
        last dimension (query, key, value) before attention is computed.
        """
        fused_qkv, _ = self.c_attn(hidden_states)
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.c_proj(context)
        return projected


class GPT2MLP(nn.Module):
    """Feed-forward sub-layer of a GPT-2 block.

    Computes ``c_proj(act(c_fc(x)))``: ``hidden_size`` is expanded to
    ``intermediate_size``, passed through the configured activation, and
    then projected back down to ``hidden_size``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        d_model = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            d_model,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            d_model,
            bias=True,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the expand -> activate -> contract pipeline."""
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.c_proj(activated)
        return contracted


class GPT2Block(nn.Module):
    """One GPT-2 transformer layer with LayerNorm applied before each sub-layer.

    Computes ``x = attn(ln_1(x)) + x`` and then ``x = x + mlp(ln_2(x))``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # The MLP width comes from ``n_inner`` when set, otherwise 4x hidden.
        if config.n_inner is None:
            inner_dim = 4 * hidden_size
        else:
            inner_dim = config.n_inner

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual add."""
        # Attention sub-layer; the un-normalized input is the residual.
        attn_out = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = attn_out + hidden_states

        # Feed-forward sub-layer; again the un-normalized input is the residual.
        mlp_out = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_out


class GPT2Model(nn.Module):
    """GPT-2 decoder stack without the language-modeling head.

    Holds the token embedding (``wte``), learned position embedding
    (``wpe``), the transformer blocks in ``h`` (only the range
    ``[start_layer, end_layer)`` is executed here; presumably the slice
    assigned to this pipeline-parallel rank by ``make_layers``), and the
    final LayerNorm ``ln_f``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        # Configs that enable any of these GPT-2 options are rejected.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn

        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda: GPT2Block(config, cache_config, quant_config))
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder.

        The first pipeline rank embeds ``input_ids`` and ``position_ids``.
        Later ranks continue from ``intermediate_tensors["hidden_states"]``.
        The last rank returns ``ln_f``-normalized hidden states. Every other
        rank returns an ``IntermediateTensors`` hand-off.

        ``kv_caches`` is indexed relative to ``start_layer``.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            hidden_states = self.wte(input_ids) + self.wpe(position_ids)

        local_layers = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(local_layers):
            hidden_states = self.h[layer_idx](hidden_states,
                                              kv_caches[cache_idx],
                                              attn_metadata)

        if get_pp_group().is_last_rank:
            return self.ln_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


class GPT2LMHeadModel(nn.Module):
    """GPT-2 decoder plus a language-modeling head, for inference.

    The LM head is the token embedding itself (``self.lm_head`` is
    ``self.transformer.wte``). For that reason, checkpoint entries named
    ``lm_head.weight`` are skipped in ``load_weights``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPT2Model(config, cache_config, quant_config)
        # Weight tying: the output projection reuses the input embedding.
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack and return its output unchanged.

        ``GPT2Model.forward`` returns hidden states on the last pipeline
        rank and an ``IntermediateTensors`` hand-off on earlier ranks, so
        this method can return either.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project ``hidden_states`` to vocabulary logits via the tied head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` using the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Build a zero-filled ``hidden_states`` pipeline hand-off buffer.

        The tensor has shape ``(batch_size, hidden_size)``.
        """
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy HuggingFace GPT-2 checkpoint tensors into this model.

        Args:
            weights: ``(name, tensor)`` pairs using HF GPT-2 parameter names,
                with or without the ``transformer.`` prefix.

        Raises:
            KeyError: if a non-skipped checkpoint name has no matching
                parameter in this model.
        """
        # HF's GPT-2 implements these projections with Conv1D instead of
        # Linear, so their ``.weight`` tensors are stored transposed.
        conv1d_module_names = ("c_attn", "c_proj", "c_fc")
        # remove_duplicate=False keeps every name for shared parameters
        # (lm_head aliases transformer.wte) in the lookup table.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped (it is preceded
                # by "_", not ".", so the checks above do not match it).
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                # Parameter belongs to a layer not held by this rank.
                continue

            param = params_dict[name]
            # NOTE(zhuohan): the logic below might break quantized models.
            # Test all Conv1D names at once so each weight is transposed at
            # most once, however many of the names appear in it.
            if name.endswith(".weight") and any(
                    conv_name in name for conv_name in conv1d_module_names):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)