# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
21

22
from collections.abc import Iterable
23
from itertools import islice
24
25
26
27
28

import torch
from torch import nn
from transformers import GPTJConfig

29
from vllm.attention import Attention
30
from vllm.compilation.decorators import support_torch_compile
31
from vllm.config import CacheConfig, VllmConfig
32
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
33
from vllm.model_executor.layers.activation import get_act_fn
34
35
36
37
38
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
from vllm.model_executor.layers.rotary_embedding import get_rope
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
43
44
45
    ParallelLMHead,
    VocabParallelEmbedding,
)
46
from vllm.model_executor.model_loader.weight_utils import (
47
48
49
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
50
from vllm.sequence import IntermediateTensors
51

52
from .interfaces import SupportsPP
53
54
55
56
57
58
59
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
60

61
62

class GPTJAttention(nn.Module):
    """GPT-J self-attention layer.

    A fused QKV projection feeds a (non-NeoX style) rotary embedding and the
    paged ``Attention`` op, followed by an output projection. Heads are
    divided evenly across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        # Both projections are bias-free and share the same quantization.
        proj_kwargs = dict(bias=False, quant_config=quant_config)
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            prefix=f"{prefix}.qkv_proj",
            **proj_kwargs,
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            prefix=f"{prefix}.out_proj",
            **proj_kwargs,
        )

        # Each tensor-parallel rank owns an equal share of the heads.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        max_positions = getattr(config, "max_position_embeddings", 8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_positions,
            rope_parameters=getattr(config, "rope_parameters", None),
            is_neox_style=False,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_size,
            self.head_size**-0.5,  # standard 1/sqrt(head_size) softmax scale
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, rotate q/k by position, attend, project back."""
        # The linear layers return (output, bias); only the output is used.
        query, key, value = self.qkv_proj(hidden_states)[0].chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        context = self.attn(query, key, value)
        projected, _bias = self.out_proj(context)
        return projected


class GPTJMLP(nn.Module):
    """GPT-J feed-forward network: ``fc_in`` -> activation -> ``fc_out``.

    ``fc_in`` is column-parallel and ``fc_out`` row-parallel, so the
    intermediate activation stays sharded across tensor-parallel ranks.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.fc_in = ColumnParallelLinear(
            config.n_embd,
            intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc_in",
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            config.n_embd,
            quant_config=quant_config,
            prefix=f"{prefix}.fc_out",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the two-layer MLP; the linears' (output, bias) tuples are
        unpacked and only the outputs are kept."""
        expanded, _ = self.fc_in(hidden_states)
        contracted, _ = self.fc_out(self.act(expanded))
        return contracted


class GPTJBlock(nn.Module):
    """One GPT-J decoder layer.

    Attention and MLP both consume the same ``ln_1`` output, and their results
    are summed together with the un-normalized input (a parallel block, not
    attention-then-MLP).
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # n_inner overrides the default 4x expansion when the config sets it.
        if config.n_inner is None:
            inner_dim = 4 * config.n_embd
        else:
            inner_dim = config.n_inner
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.attn"
        )
        self.mlp = GPTJMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``attn(ln_1(x)) + mlp(ln_1(x)) + x``."""
        normed = self.ln_1(hidden_states)
        attn_out = self.attn(
            position_ids=position_ids,
            hidden_states=normed,
        )
        mlp_out = self.mlp(normed)
        # Keep the summation order identical: attention, MLP, then residual.
        return attn_out + mlp_out + hidden_states


191
@support_torch_compile
class GPTJModel(nn.Module):
    """GPT-J transformer backbone: token embedding, decoder blocks, final norm.

    Under pipeline parallelism only the blocks in
    ``[self.start_layer, self.end_layer)`` run on this rank; the first rank
    produces the embeddings and the last rank applies ``ln_f``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.start_layer, self.end_layer, self.h = make_layers(
            config.n_layer,
            lambda prefix: GPTJBlock(config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.n_embd
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; embedded on the first PP rank unless
                ``inputs_embeds`` is given.
            position_ids: Positions passed to every block's rotary embedding.
            intermediate_tensors: Activations from the previous PP rank;
                required on every rank except the first.
            inputs_embeds: Precomputed embeddings used instead of
                ``input_ids`` on the first rank.

        Returns:
            ``IntermediateTensors`` on non-last ranks, otherwise the
            ``ln_f``-normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
        else:
            # Later PP ranks cannot proceed without the previous rank's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = layer(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are routed into
        the matching shard of the fused ``qkv_proj`` parameter; any other
        tensor is loaded into the parameter of the same (possibly remapped)
        name.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            Names of the parameters that received a tensor.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # No parameter here corresponds to these checkpoint entries
            # (presumably HF attention mask buffers).
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # A non-scalar scale tensor is reduced to its first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Resolve the target parameter exactly once. The first matching
            # shard wins and the loop stops, so the renamed "...qkv_proj..."
            # is never re-matched by a later entry ("v_proj" is a substring
            # of "qkv_proj" and would otherwise corrupt the name).
            shard_id = None
            for param_name, weight_name, stacked_shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break
            else:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

297

298
class GPTJForCausalLM(nn.Module, SupportsPP):
    """GPT-J causal LM: ``GPTJModel`` backbone plus a separate, biased LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.quant_config = vllm_config.quant_config
        # The LM head is built as its own parameter; tying to wte is unsupported.
        assert not hf_config.tie_word_embeddings

        self.transformer = GPTJModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.n_embd,
            bias=True,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Delegate PP placeholder creation to the backbone.
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the backbone's embedding table."""
        return self.transformer.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        return self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits, including the head bias."""
        return self.logits_processor(self.lm_head, hidden_states, self.lm_head.bias)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights via ``AutoWeightsLoader`` and return the loaded names."""
        return AutoWeightsLoader(self).load_weights(weights)