"vllm/model_executor/models/solar.py" did not exist on "7025b11d949b4efeb2584690c35f919c77027368"
gpt2.py 12.2 KB
Newer Older
1
2
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
19
from typing import Iterable, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
20
21
22
23
24

import torch
from torch import nn
from transformers import GPT2Config

25
from vllm.attention import Attention, AttentionMetadata
26
from vllm.compilation.decorators import support_torch_compile
27
from vllm.config import CacheConfig
28
29
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
30
from vllm.model_executor.layers.activation import get_act_fn
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
from vllm.model_executor.layers.quantization import QuantizationConfig
36
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
38
    ParallelLMHead, VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
42

43
44
45
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
46

Woosuk Kwon's avatar
Woosuk Kwon committed
47
48
49

class GPT2Attention(nn.Module):
    """Self-attention sub-layer of a GPT-2 block.

    Built from a fused query/key/value projection (``c_attn``), the
    ``Attention`` op and an output projection (``c_proj``). The attention
    heads are divided evenly among the tensor-parallel ranks.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the projections and the attention op.

        Args:
            config: GPT-2 configuration supplying ``hidden_size`` and
                ``num_attention_heads``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to both linear layers and ``Attention``.
            prefix: Dotted name of this module; sub-layer prefixes are
                derived from it.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        head_count = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # Each tensor-parallel rank must own a whole number of heads.
        assert head_count % tp_size == 0
        self.num_heads = head_count // tp_size
        self.head_dim = self.hidden_size // head_count
        self.scale = self.head_dim**-0.5

        # Both projections carry a bias and share the quantization settings.
        linear_kwargs = dict(bias=True, quant_config=quant_config)
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            head_count,
            prefix=f"{prefix}.c_attn",
            **linear_kwargs,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            prefix=f"{prefix}.c_proj",
            **linear_kwargs,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, run attention, then apply the output projection."""
        # The parallel linear layers return a pair; only the first element
        # (the projected activations) is used here.
        fused_qkv = self.c_attn(hidden_states)[0]
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        return self.c_proj(context)[0]


class GPT2MLP(nn.Module):
    """Feed-forward sub-layer of a GPT-2 block.

    Applies ``c_fc`` (hidden -> intermediate), the activation configured by
    ``config.activation_function``, then ``c_proj`` (intermediate -> hidden).
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the two linear layers and the activation.

        Args:
            intermediate_size: Width of the inner projection.
            config: GPT-2 configuration supplying ``hidden_size`` and
                ``activation_function``.
            quant_config: Forwarded to both linear layers and to
                ``get_act_fn``.
            prefix: Dotted name of this module; sub-layer prefixes are
                derived from it.
        """
        super().__init__()
        model_dim = config.hidden_size
        # Both projections carry a bias and share the quantization settings.
        linear_kwargs = dict(bias=True, quant_config=quant_config)
        self.c_fc = ColumnParallelLinear(
            model_dim,
            intermediate_size,
            prefix=f"{prefix}.c_fc",
            **linear_kwargs,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            model_dim,
            prefix=f"{prefix}.c_proj",
            **linear_kwargs,
        )
        self.act = get_act_fn(
            config.activation_function,
            quant_config,
            intermediate_size,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run up-projection, activation and down-projection in sequence."""
        # The parallel linear layers return a pair; only the first element
        # (the projected activations) is used here.
        expanded = self.c_fc(hidden_states)[0]
        activated = self.act(expanded)
        return self.c_proj(activated)[0]


class GPT2Block(nn.Module):
    """One GPT-2 transformer layer.

    Each of the two sub-layers (attention, MLP) is preceded by a LayerNorm
    and followed by a residual addition.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the layer norms, the attention module and the MLP.

        Args:
            config: GPT-2 configuration; ``n_inner`` sets the MLP width and
                falls back to ``4 * hidden_size`` when it is ``None``.
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to the attention module and the MLP.
            prefix: Dotted name of this block; sub-module prefixes are
                derived from it.
        """
        super().__init__()
        model_dim = config.hidden_size
        norm_eps = config.layer_norm_epsilon
        mlp_dim = config.n_inner
        if mlp_dim is None:
            mlp_dim = 4 * model_dim

        self.ln_1 = nn.LayerNorm(model_dim, eps=norm_eps)
        self.attn = GPT2Attention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.attn",
        )
        self.ln_2 = nn.LayerNorm(model_dim, eps=norm_eps)
        self.mlp = GPT2MLP(
            mlp_dim,
            config,
            quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual."""
        # Attention sub-layer: normalise, attend, add the input back.
        attn_output = self.attn(
            hidden_states=self.ln_1(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = attn_output + hidden_states

        # MLP sub-layer: normalise, transform, add the input back.
        mlp_output = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_output


185
@support_torch_compile
class GPT2Model(nn.Module):
    """GPT-2 backbone: embeddings, the stack of blocks and the final norm.

    Pipeline-parallel aware: the first rank embeds the tokens, later ranks
    read ``hidden_states`` from ``intermediate_tensors``, and only the last
    rank applies ``ln_f``.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the embeddings, the blocks and the final LayerNorm.

        Args:
            config: GPT-2 configuration.
            cache_config: Forwarded to every block.
            quant_config: Forwarded to every block.
            prefix: Dotted name of this module; block prefixes are
                ``f"{prefix}.h.<idx>"`` as assigned by ``make_layers``.
        """
        super().__init__()
        self.config = config
        # These configuration options must all be disabled.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn

        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        # NOTE: the parameter must stay named ``prefix`` in case
        # ``make_layers`` passes it by keyword.
        def build_block(prefix: str) -> GPT2Block:
            return GPT2Block(config,
                             cache_config,
                             quant_config,
                             prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the blocks owned by this pipeline rank.

        Returns:
            The hidden states after ``ln_f`` on the last pipeline rank,
            otherwise an ``IntermediateTensors`` holding ``hidden_states``
            for the next rank.
        """
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            # Token embeddings plus learned position embeddings.
            hidden_states = self.wte(input_ids) + self.wpe(position_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # ``kv_caches`` is indexed relative to this rank's first layer.
        local_layers = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(local_layers):
            hidden_states = self.h[layer_idx](hidden_states,
                                              kv_caches[cache_idx],
                                              attn_metadata)

        if not pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return self.ln_f(hidden_states)


242
class GPT2LMHeadModel(nn.Module, SupportsPP):
    """GPT-2 backbone plus language-modeling head, logits processor and sampler.

    When ``config.tie_word_embeddings`` is set, the head is the very same
    module as the token embedding (``transformer.wte``); otherwise a separate
    ``ParallelLMHead`` is created and must receive its own weights.
    """

    def __init__(
        self,
        config: GPT2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the backbone, the LM head and the sampling helpers.

        Args:
            config: GPT-2 configuration.
            cache_config: Forwarded to the backbone.
            quant_config: Stored on the instance and forwarded to the
                backbone.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPT2Model(config,
                                     cache_config,
                                     quant_config,
                                     prefix="transformer")
        if self.config.tie_word_embeddings:
            # Share the embedding module itself, not a copy of its weights.
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the logits computed by the logits processor via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Args:
            weights: ``(name, tensor)`` pairs read from a checkpoint. Names
                without the ``transformer.`` prefix are treated as backbone
                parameters and get the prefix added.

        Raises:
            KeyError: If a backbone tensor name (after prefixing) does not
                match any parameter of this model.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        tied_head = self.config.tie_word_embeddings
        # Layers that the HF implementation stores as Conv1D.
        conv1d_markers = ("c_attn", "c_proj", "c_fc")

        for name, loaded_weight in weights:
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue

            if "lm_head.weight" in name:
                # With tied embeddings the head is ``transformer.wte`` and is
                # populated through that tensor. With an untied head the
                # tensor has to be loaded, otherwise ``ParallelLMHead`` would
                # keep whatever values it was created with. Names that do
                # not map onto a parameter are skipped, as before.
                if tied_head or name not in params_dict:
                    continue
            elif not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights.
            # Note(zhuohan): the logic below might break quantized models.
            # A single check guarantees the tensor is transposed at most once.
            if name.endswith(".weight") and any(marker in name
                                                for marker in conv1d_markers):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)