gpt2.py 13.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
Woosuk Kwon's avatar
Woosuk Kwon committed
5
# Copyright 2023 The vLLM team.
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
21
from typing import Iterable, List, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
24
25
26

import torch
from torch import nn
from transformers import GPT2Config

27
from vllm.attention import Attention, AttentionMetadata
28
from vllm.compilation.decorators import support_torch_compile
29
from vllm.config import CacheConfig, VllmConfig
30
31
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
32
from vllm.model_executor.layers.activation import get_act_fn
33
34
35
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
36
from vllm.model_executor.layers.logits_processor import LogitsProcessor
37
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
38
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
39
from vllm.model_executor.layers.vocab_parallel_embedding import (
40
    ParallelLMHead, VocabParallelEmbedding)
41
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
42
from vllm.model_executor.sampling_metadata import SamplingMetadata
43
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
44

45
46
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
47
48
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
49

Woosuk Kwon's avatar
Woosuk Kwon committed
50
51
52

class GPT2Attention(nn.Module):
    """Self-attention sub-layer of a GPT-2 block.

    Holds three sub-modules: ``c_attn`` (a fused QKV projection),
    ``attn`` (the attention operator) and ``c_proj`` (the output
    projection). ``num_heads`` is the total head count divided by the
    tensor-model-parallel world size.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections and the attention operator.

        Args:
            config: GPT-2 config; ``hidden_size`` and
                ``num_attention_heads`` are read from it.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to every sub-layer.
            prefix: Dotted name of this module; sub-layer names are
                appended to it.
        """
        super().__init__()
        tp_world_size = get_tensor_model_parallel_world_size()
        head_count = config.num_attention_heads
        # The heads must divide evenly across tensor-parallel ranks.
        assert head_count % tp_world_size == 0

        self.hidden_size = config.hidden_size
        self.num_heads = head_count // tp_world_size
        self.head_dim = self.hidden_size // head_count
        self.scale = self.head_dim**-0.5

        qkv_prefix = f"{prefix}.c_attn"
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            head_count,
            bias=True,
            quant_config=quant_config,
            prefix=qkv_prefix,
        )

        out_prefix = f"{prefix}.c_proj"
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=out_prefix,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, run attention, then apply the output projection."""
        fused_qkv, _ = self.c_attn(hidden_states)
        # The fused projection is split into three equal chunks on the
        # last dimension: query, key, value.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.c_proj(context)
        return projected


class GPT2MLP(nn.Module):
    """Feed-forward sub-layer of a GPT-2 block.

    Applies ``c_fc`` (hidden -> intermediate), the activation named by
    ``config.activation_function``, then ``c_proj`` (intermediate ->
    hidden).
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the two linear layers and the activation.

        Args:
            intermediate_size: Width of the inner projection.
            config: GPT-2 config; ``hidden_size`` and
                ``activation_function`` are read from it.
            quant_config: Forwarded to both linear layers.
            prefix: Dotted name of this module; sub-layer names are
                appended to it.
        """
        super().__init__()
        model_dim = config.hidden_size

        self.c_fc = ColumnParallelLinear(
            model_dim,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            model_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``c_proj(act(c_fc(hidden_states)))``."""
        # Both linear layers return a pair; only the first item is used.
        expanded, _ = self.c_fc(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.c_proj(activated)
        return contracted


class GPT2Block(nn.Module):
    """One GPT-2 transformer layer.

    Each of the two sub-layers (attention, MLP) is preceded by its own
    LayerNorm and followed by a residual addition.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the layer norms, the attention and the MLP.

        Args:
            config: GPT-2 config; ``hidden_size``, ``n_inner`` and
                ``layer_norm_epsilon`` are read from it.
            cache_config: Forwarded to the attention sub-layer.
            quant_config: Forwarded to the attention and MLP sub-layers.
            prefix: Dotted name of this block; sub-layer names are
                appended to it.
        """
        super().__init__()
        model_dim = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        # Fall back to 4 * hidden_size when the config leaves n_inner unset.
        mlp_dim = config.n_inner
        if mlp_dim is None:
            mlp_dim = 4 * model_dim

        self.ln_1 = nn.LayerNorm(model_dim, eps=norm_eps)
        self.attn = GPT2Attention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.attn",
        )
        self.ln_2 = nn.LayerNorm(model_dim, eps=norm_eps)
        self.mlp = GPT2MLP(
            mlp_dim,
            config,
            quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers, each with a residual add."""
        # Attention sub-layer: norm -> attention -> add the un-normed input.
        attn_input = hidden_states
        attn_output = self.attn(
            hidden_states=self.ln_1(attn_input),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        after_attn = attn_output + attn_input

        # MLP sub-layer: norm -> MLP -> add the un-normed input.
        mlp_output = self.mlp(self.ln_2(after_attn))
        return after_attn + mlp_output


188
@support_torch_compile
class GPT2Model(nn.Module):
    """GPT-2 transformer stack without the language-model head.

    Holds the token embedding (``wte``), the position embedding
    (``wpe``), the transformer blocks (``h``) and the final LayerNorm
    (``ln_f``). ``start_layer`` and ``end_layer`` come from
    ``make_layers`` and bound the blocks that ``forward`` executes.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embeddings, blocks and final norm from ``vllm_config``.

        Args:
            vllm_config: Source of the HF model config, the cache config
                and the quantization config.
            prefix: Dotted name of this module; sub-module names are
                appended to it.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        # GPT-2 variants this implementation does not handle.
        assert not hf_config.add_cross_attention
        assert not hf_config.scale_attn_by_inverse_layer_idx
        assert not hf_config.reorder_and_upcast_attn

        self.embed_dim = hf_config.hidden_size
        self.wte = VocabParallelEmbedding(
            hf_config.vocab_size,
            self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.wte",
        )
        self.wpe = nn.Embedding(hf_config.max_position_embeddings,
                                self.embed_dim)

        def build_block(prefix: str) -> GPT2Block:
            # The parameter keeps the name ``prefix`` in case the factory
            # is invoked by keyword.
            return GPT2Block(hf_config,
                             cache_config,
                             quant_config,
                             prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.num_hidden_layers,
            build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim,
                                 eps=hf_config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the blocks in ``[start_layer, end_layer)``.

        On the first pipeline rank the input is the sum of token and
        position embeddings (``inputs_embeds`` is used in place of the
        token embeddings when given). On other ranks the input is
        ``intermediate_tensors["hidden_states"]``.

        Returns:
            ``IntermediateTensors`` wrapping the hidden states when this
            is not the last pipeline rank; otherwise the hidden states
            after the final LayerNorm.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            hidden_states = inputs_embeds + self.wpe(position_ids)

        # kv_caches is indexed relative to the first block run here.
        block_indices = range(self.start_layer, self.end_layer)
        for cache_idx, block_idx in enumerate(block_indices):
            block = self.h[block_idx]
            hidden_states = block(hidden_states, kv_caches[cache_idx],
                                  attn_metadata)

        if get_pp_group().is_last_rank:
            return self.ln_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


252
class GPT2LMHeadModel(nn.Module, SupportsPP):
    """GPT-2 with a language-model head, for inference.

    Wraps ``GPT2Model`` as ``transformer`` and adds ``lm_head``, a logits
    processor and a sampler. When ``config.tie_word_embeddings`` is set,
    ``lm_head`` is tied to the token embedding ``transformer.wte``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the transformer, the LM head and the sampling helpers.

        Args:
            vllm_config: Source of the HF model config and the
                quantization config; also passed through to ``GPT2Model``.
            prefix: Dotted name of this module; sub-module names are
                derived from it.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        self.transformer = GPT2Model(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        # Name the head with maybe_prefix, exactly like the transformer
        # above, so that the default empty ``prefix`` does not produce a
        # name with a leading dot (".lm_head").
        self.lm_head = ParallelLMHead(self.config.vocab_size,
                                      self.config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if self.config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.transformer.wte)

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped transformer and return its output unchanged."""
        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata, intermediate_tensors,
                                inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits via ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The (possibly ``transformer.``-prefixed) names that were
            loaded.

        Raises:
            KeyError: If a name that is neither skipped nor missing on
                this pipeline rank has no matching parameter.
        """
        # Sub-strings that mark layers stored as Conv1D in the HF checkpoint.
        conv1d_markers = ("c_attn", "c_proj", "c_fc")

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            # Names lacking a known top-level prefix are treated as
            # belonging to the transformer.
            if not name.startswith(("transformer.", "lm_head")):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights.
            # Note(zhuohan): the logic below might break quantized models.
            # A single check guarantees at most one transpose per tensor,
            # even if a name were to contain more than one marker.
            if name.endswith(".weight") and any(marker in name
                                                for marker in conv1d_markers):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params