gpt_j.py 12.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only GPT-J model compatible with HuggingFace weights."""
21

22
from collections.abc import Iterable
23
from itertools import islice
24
from typing import Optional, Union
25
26
27
28
29

import torch
from torch import nn
from transformers import GPTJConfig

30
from vllm.attention import Attention
31
from vllm.compilation.decorators import support_torch_compile
32
from vllm.config import CacheConfig, VllmConfig
33
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
34
from vllm.model_executor.layers.activation import get_act_fn
35
36
37
38
39
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
40
from vllm.model_executor.layers.logits_processor import LogitsProcessor
41
from vllm.model_executor.layers.quantization import QuantizationConfig
42
from vllm.model_executor.layers.rotary_embedding import get_rope
43
from vllm.model_executor.layers.vocab_parallel_embedding import (
44
45
46
    ParallelLMHead,
    VocabParallelEmbedding,
)
47
from vllm.model_executor.model_loader.weight_utils import (
48
49
50
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
51
from vllm.sequence import IntermediateTensors
52

53
from .interfaces import SupportsPP
54
55
56
57
58
59
60
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
61

62
63

class GPTJAttention(nn.Module):
    """Self-attention sub-layer of a GPT-J block.

    Owns a fused QKV projection, a rotary position embedding, the attention
    op and the output projection. The attention heads are divided evenly
    across the tensor-parallel world size.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections, the rotary embedding and the attention op.

        Args:
            config: GPT-J configuration; ``num_attention_heads``,
                ``hidden_size`` and ``rotary_dim`` are read from it, plus the
                optional ``rotary``, ``rope_theta`` and
                ``max_position_embeddings`` attributes.
            cache_config: Forwarded unchanged to ``Attention``.
            quant_config: Forwarded to the linear layers and to ``Attention``.
            prefix: Dotted name of this module inside the model. It is used
                to derive the names handed to the child layers.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        # Every child layer receives its own dotted name so that
        # prefix-keyed logic (e.g. per-layer quantization settings) can
        # identify it; previously only ``self.attn`` got one.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Each tensor-parallel rank handles an equal share of the heads.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        # Only the rotary variant with an even rotary dimension is handled.
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_size,
            scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run rotary-embedded attention followed by the output projection.

        Args:
            position_ids: Positions handed to the rotary embedding.
            hidden_states: Input activations for the QKV projection.

        Returns:
            The output of ``out_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is split into three equal parts on the last
        # dimension: query, key, value.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):
    """Feed-forward sub-layer of a GPT-J block: ``fc_in`` -> act -> ``fc_out``."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the two projections and the activation function.

        Args:
            intermediate_size: Width of the hidden layer between the two
                projections.
            config: GPT-J configuration; ``n_embd`` and
                ``activation_function`` are read from it.
            quant_config: Forwarded to both linear layers.
        """
        super().__init__()
        embed_dim = config.n_embd
        # Up-projection to the intermediate width ...
        self.fc_in = ColumnParallelLinear(
            embed_dim,
            intermediate_size,
            quant_config=quant_config,
        )
        # ... and down-projection back to the embedding width.
        self.fc_out = RowParallelLinear(
            intermediate_size,
            embed_dim,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc_in``, the activation and ``fc_out`` in that order."""
        # Both linear layers return a 2-tuple; only the first item is used.
        projected, _ = self.fc_in(hidden_states)
        activated = self.act(projected)
        output, _ = self.fc_out(activated)
        return output


class GPTJBlock(nn.Module):
    """One GPT-J layer.

    A single layer norm feeds both the attention and the MLP branch; the two
    branch outputs are summed together with the un-normalized input.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the layer norm, attention and MLP sub-modules.

        Args:
            config: GPT-J configuration; ``n_embd``, ``n_inner`` and
                ``layer_norm_epsilon`` are read here.
            cache_config: Forwarded to ``GPTJAttention``.
            quant_config: Forwarded to ``GPTJAttention`` and ``GPTJMLP``.
            prefix: Dotted name of this block; the attention module is
                created under ``f"{prefix}.attn"``.
        """
        super().__init__()
        # MLP width falls back to four times the embedding width when the
        # config does not specify ``n_inner``.
        if config.n_inner is not None:
            mlp_width = config.n_inner
        else:
            mlp_width = 4 * config.n_embd
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.attn"
        )
        self.mlp = GPTJMLP(mlp_width, config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return attention output + MLP output + the original input.

        Args:
            position_ids: Positions forwarded to the attention module.
            hidden_states: Input activations of this layer.
        """
        normed = self.ln_1(hidden_states)
        # Both branches read the same normalized activations.
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=normed,
        )
        mlp_output = self.mlp(normed)
        return attn_output + mlp_output + hidden_states


188
@support_torch_compile
class GPTJModel(nn.Module):
    """GPT-J transformer stack: token embedding, layers and final layer norm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the layers and the final norm.

        Args:
            vllm_config: Source of the HF model config, the cache config and
                the quantization config.
            prefix: Dotted name of this module; layers are created under
                ``f"{prefix}.h"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        # ``make_layers`` also reports which slice of the layer list this
        # instance iterates over in ``forward``.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.n_layer,
            lambda prefix: GPTJBlock(config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.n_embd
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the layers in ``[start_layer, end_layer)``.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            position_ids: Positions forwarded to every layer.
            intermediate_tensors: Carries ``"hidden_states"`` into this
                instance when it is not the first pipeline rank.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first pipeline rank.

        Returns:
            ``IntermediateTensors`` holding ``"hidden_states"`` when this is
            not the last pipeline rank, otherwise the output of ``ln_f``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = layer(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are routed into
        the fused ``qkv_proj`` parameter; KV-cache scales reported by the
        quantization config are loaded under their remapped name; everything
        else is loaded under its (possibly remapped) own name.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            Names of the parameters that were loaded.
        """
        # Only the attention projections are fused in this model; the MLP
        # uses plain ``fc_in`` / ``fc_out`` and needs no stacking rules.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        def should_skip(candidate: str) -> bool:
            # Skip loading extra bias for GPTQ models.
            if candidate.endswith(".bias") and candidate not in params_dict:
                return True
            return is_pp_missing_parameter(candidate, self)

        for name, loaded_weight in weights:
            # Entries with these name fragments are ignored outright.
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Stop at the first matching shard whether or not it is
                # loaded. Scanning further would re-match "v_proj" inside
                # the freshly written "qkv_proj" and corrupt the name.
                if not should_skip(name):
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    loaded_params.add(name)
                break
            else:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if should_skip(name):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        return loaded_params

294

295
class GPTJForCausalLM(nn.Module, SupportsPP):
    """GPT-J transformer with a language-model head on top."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the transformer, the LM head and the logits processor.

        Args:
            vllm_config: Source of the HF model config and the quantization
                config; also handed to ``GPTJModel`` as-is.
            prefix: Dotted name of this module, used to derive the names of
                the ``transformer`` and ``lm_head`` children.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quantization = vllm_config.quant_config
        self.config = hf_config
        self.quant_config = quantization
        # Tied input/output embeddings are not handled by this class.
        assert not hf_config.tie_word_embeddings
        self.transformer = GPTJModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.n_embd,
            bias=True,
            quant_config=quantization,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Re-export the transformer's factory under the same attribute name.
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the wrapped transformer."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped transformer and return its result unchanged."""
        return self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits using ``lm_head`` and its bias."""
        head = self.lm_head
        return self.logits_processor(head, hidden_states, head.bias)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors through ``AutoWeightsLoader``.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns for
            ``weights``.
        """
        return AutoWeightsLoader(self).load_weights(weights)