gpt_j.py 12.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only GPT-J model compatible with HuggingFace weights."""
21
from collections.abc import Iterable
22
from itertools import islice
23
from typing import Optional, Union
24
25
26
27
28

import torch
from torch import nn
from transformers import GPTJConfig

29
from vllm.attention import Attention
30
from vllm.compilation.decorators import support_torch_compile
31
from vllm.config import CacheConfig, VllmConfig
32
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
33
from vllm.model_executor.layers.activation import get_act_fn
34
35
36
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
37
from vllm.model_executor.layers.logits_processor import LogitsProcessor
38
from vllm.model_executor.layers.quantization import QuantizationConfig
39
from vllm.model_executor.layers.rotary_embedding import get_rope
40
from vllm.model_executor.layers.vocab_parallel_embedding import (
41
    ParallelLMHead, VocabParallelEmbedding)
42
43
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
44
from vllm.sequence import IntermediateTensors
45

46
from .interfaces import SupportsPP
47
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
48
49
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
50

51
52
53

class GPTJAttention(nn.Module):
    """Self-attention sub-layer of a GPT-J block.

    Q, K and V come from one fused ``QKVParallelLinear`` projection. Rotary
    position embeddings (non-NeoX style) are applied to Q and K, attention
    runs through vLLM's ``Attention`` layer, and the result is projected back
    to ``hidden_size`` by the row-parallel ``out_proj``.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projection, rotary and attention sub-layers.

        Args:
            config: HuggingFace GPT-J config. Reads ``num_attention_heads``,
                ``hidden_size`` and ``rotary_dim``, plus ``rotary``,
                ``rope_theta`` and ``max_position_embeddings`` when present.
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Optional quantization config forwarded to every
                linear layer and to ``Attention``.
            prefix: Fully-qualified module name of this layer. Sub-layer
                names are derived from it.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        # Every sub-layer gets its fully-qualified name, so quantization
        # configs that select or skip layers by name can match it. This is
        # consistent with the ``Attention`` layer below.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        # Heads are split evenly across tensor-parallel ranks.
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run rotary self-attention over ``hidden_states``.

        Args:
            position_ids: Token positions consumed by the rotary embedding.
            hidden_states: Input activations of this layer.

        Returns:
            The attention output after ``out_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # ``qkv_proj`` is built without a separate KV-head count, so Q, K and
        # V have equal widths and a 3-way chunk on the last dim splits them.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):
    """Feed-forward sub-layer of a GPT-J block.

    The column-parallel ``fc_in`` expands ``n_embd`` to
    ``intermediate_size``. The configured activation is then applied, and the
    row-parallel ``fc_out`` projects back to ``n_embd``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the two projections and the activation.

        Args:
            intermediate_size: Width of the hidden (expanded) layer.
            config: HuggingFace GPT-J config. Reads ``n_embd`` and
                ``activation_function``.
            quant_config: Optional quantization config for both linears.
            prefix: Fully-qualified module name of this MLP. It is used to
                name ``fc_in``/``fc_out`` so name-based quantization configs
                can match them.
        """
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc_in",
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc_out",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc_in`` -> activation -> ``fc_out`` and return the result."""
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):
    """One GPT-J transformer layer with parallel attention and MLP branches.

    A single LayerNorm (``ln_1``) feeds both the attention and the MLP.
    Their outputs are summed together with the un-normalised residual.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the layer norm, attention and MLP sub-layers.

        Args:
            config: HuggingFace GPT-J config.
            cache_config: KV-cache configuration forwarded to attention.
            quant_config: Optional quantization config for the sub-layers.
            prefix: Fully-qualified module name of this block. Sub-layers
                are named ``{prefix}.attn`` and ``{prefix}.mlp``.
        """
        super().__init__()
        # When ``n_inner`` is unset (None), the MLP width defaults to
        # 4 * n_embd.
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.mlp = GPTJMLP(inner_dim,
                           config,
                           quant_config,
                           prefix=f"{prefix}.mlp")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute ``attn(ln_1(x)) + mlp(ln_1(x)) + x``.

        Args:
            position_ids: Token positions forwarded to attention.
            hidden_states: Input activations ``x`` of this block.

        Returns:
            The block output, with the same shape as ``hidden_states``.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        # The MLP consumes the same normalised input as attention (parallel
        # branches), not the attention output.
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


181
@support_torch_compile
class GPTJModel(nn.Module):
    """GPT-J decoder stack: token embedding, transformer blocks, final norm.

    With pipeline parallelism, only the first rank embeds tokens. Each rank
    runs its ``[start_layer, end_layer)`` slice of ``self.h``. Only the last
    rank applies ``ln_f``; the other ranks hand ``hidden_states`` on via
    ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the (pipeline-sliced) blocks and ``ln_f``.

        Args:
            vllm_config: Engine config. Provides the HF model config, the
                cache config and the quantization config.
            prefix: Fully-qualified module name. Blocks live under
                ``{prefix}.h``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.start_layer, self.end_layer, self.h = make_layers(
            config.n_layer,
            lambda prefix: GPTJBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids`` via ``wte``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's share of the decoder stack.

        Args:
            input_ids: Token ids. Used only on the first pipeline rank, and
                only when ``inputs_embeds`` is None.
            position_ids: Token positions forwarded to every block.
            intermediate_tensors: Output of the previous pipeline rank. It
                is required on every rank except the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                ``wte`` lookup on the first rank.

        Returns:
            On the last pipeline rank, the ``ln_f``-normalised hidden states.
            On any other rank, an ``IntermediateTensors`` carrying
            ``hidden_states`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Fail with a clear message rather than "'NoneType' object is
            # not subscriptable" if a later rank is called without inputs.
            assert intermediate_tensors is not None, (
                "intermediate_tensors is required on non-first PP ranks")
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = layer(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are packed into
        the fused ``qkv_proj`` through its shard-aware ``weight_loader``.
        KV-cache scales reported by the quantization config are loaded as
        scalars. Everything else goes through the parameter's own loader,
        or ``default_weight_loader`` if it has none. Tensors belonging to
        layers that live on another pipeline rank are skipped.

        Args:
            weights: ``(name, tensor)`` pairs with names relative to this
                module.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # These names appear to be non-parameter attention buffers from
            # the HF checkpoint; there is nothing to load them into.
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Keep ``name`` untouched until the shard is actually
                # loaded. Rewriting it here would let a later mapping entry
                # match the rewritten string after a ``continue`` ("v_proj"
                # is a substring of "qkv_proj"), producing mangled names.
                stacked_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (stacked_name.endswith(".bias")
                        and stacked_name not in params_dict):
                    continue
                if is_pp_missing_parameter(stacked_name, self):
                    continue
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                name = stacked_name
                break
            else:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

290

291
class GPTJForCausalLM(nn.Module, SupportsPP):
    """Causal-LM wrapper around ``GPTJModel`` with a biased, untied LM head.

    ``forward`` returns hidden states (or pipeline intermediates).
    ``compute_logits`` turns the final hidden states into vocabulary logits
    using ``lm_head`` and its bias.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack, the LM head and the logits processor.

        Args:
            vllm_config: Engine config. Provides the HF config and the
                quantization config.
            prefix: Fully-qualified module name of this model.
        """
        super().__init__()
        hf_cfg = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config
        self.config = hf_cfg
        self.quant_config = quant_cfg
        # The LM head is created as its own parameter below, so tied
        # embeddings are not supported by this implementation.
        assert not hf_cfg.tie_word_embeddings

        transformer_prefix = maybe_prefix(prefix, "transformer")
        self.transformer = GPTJModel(vllm_config=vllm_config,
                                     prefix=transformer_prefix)

        head_prefix = maybe_prefix(prefix, "lm_head")
        self.lm_head = ParallelLMHead(
            hf_cfg.vocab_size,
            hf_cfg.n_embd,
            bias=True,
            quant_config=quant_cfg,
            prefix=head_prefix,
        )
        self.logits_processor = LogitsProcessor(hf_cfg.vocab_size)
        # Reuse the decoder's factory for empty pipeline-parallel buffers.
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the token-embedding lookup to the decoder stack."""
        embeddings = self.transformer.get_input_embeddings(input_ids)
        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack and return whatever it produces on this rank."""
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` to logits via ``lm_head`` plus its bias."""
        head = self.lm_head
        return self.logits_processor(head, hidden_states, head.bias)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Dispatch checkpoint tensors to submodules via ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)