gpt_j.py 12.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
19
"""Inference-only GPT-J model compatible with HuggingFace weights."""
20
21
from collections.abc import Iterable
from typing import Optional, Union
22
23
24
25
26

import torch
from torch import nn
from transformers import GPTJConfig

27
from vllm.attention import Attention
28
from vllm.compilation.decorators import support_torch_compile
29
from vllm.config import CacheConfig, VllmConfig
30
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
31
from vllm.model_executor.layers.activation import get_act_fn
32
33
34
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
35
from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
from vllm.model_executor.layers.quantization import QuantizationConfig
37
from vllm.model_executor.layers.rotary_embedding import get_rope
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
39
    ParallelLMHead, VocabParallelEmbedding)
40
41
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
42
from vllm.model_executor.sampling_metadata import SamplingMetadata
43
from vllm.sequence import IntermediateTensors
44

45
from .interfaces import SupportsPP
46
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
47
48
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
49

50
51
52

class GPTJAttention(nn.Module):
    """Self-attention sub-layer of a GPT-J block.

    Q, K and V come from a single fused tensor-parallel projection
    (``qkv_proj``). Rotary embeddings (``is_neox_style=False``,
    ``rotary_dim=config.rotary_dim``) are applied to Q and K before calling
    vLLM's :class:`Attention` layer. The result is projected back with
    ``out_proj``.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        # Give each projection its fully-qualified prefix so quantization
        # configs that select/skip layers by module name can match them.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Heads are split evenly across tensor-parallel ranks.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        # Only the rotary variant of GPT-J is supported.
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            position_ids: Positions passed to the rotary embedding.
            hidden_states: Input activations (already layer-normed by the
                caller).

        Returns:
            The output of ``out_proj`` applied to the attention output.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection output is split evenly into Q, K and V
        # along the last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):
    """Feed-forward sub-layer of a GPT-J block: ``fc_in -> act -> fc_out``."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """
        Args:
            intermediate_size: Width of the hidden projection.
            config: HF GPT-J config; ``n_embd`` and ``activation_function``
                are read from it.
            quant_config: Optional quantization config for both linears.
            prefix: Module-name prefix used to build ``fc_in``/``fc_out``
                names for quantization lookup. Defaults to ``""`` so existing
                callers that do not pass it keep working.
        """
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc_in",
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc_out",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc_in``, the activation, then ``fc_out``."""
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):
    """One GPT-J decoder layer.

    A single LayerNorm (``ln_1``) feeds both the attention and the MLP. Their
    outputs are summed together with the residual (see ``forward``).
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # HF convention: n_inner=None means 4 * hidden size.
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        # Forward the prefix so the MLP linears get fully-qualified names.
        self.mlp = GPTJMLP(inner_dim,
                           config,
                           quant_config,
                           prefix=f"{prefix}.mlp")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``attn(ln_1(x)) + mlp(ln_1(x)) + x``.

        Attention and MLP both consume the same normalized input, and both
        outputs are added to the residual.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


180
@support_torch_compile
class GPTJModel(nn.Module):
    """GPT-J decoder stack: token embedding, decoder blocks, final LayerNorm.

    With pipeline parallelism this rank runs only the blocks in
    ``self.h[self.start_layer:self.end_layer]``. The first rank embeds tokens,
    non-last ranks return :class:`IntermediateTensors`, and only the last rank
    applies ``ln_f``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.start_layer, self.end_layer, self.h = make_layers(
            config.n_layer,
            lambda prefix: GPTJBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings with ``wte``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Returns the final normed hidden states on the last pipeline rank,
        otherwise an ``IntermediateTensors`` with key ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Non-first pipeline ranks must receive the previous rank's
            # activations; fail clearly instead of with a TypeError.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.h[self.start_layer:self.end_layer]:
            hidden_states = layer(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` checkpoint weights are
        loaded as shards of the fused ``qkv_proj``. Parameters owned by
        another pipeline rank are skipped.

        Returns:
            The names (relative to this module) of the parameters loaded.
        """
        # GPT-J's MLP is fc_in/fc_out, so only the attention projections
        # need to be stacked into a fused parameter.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # HF attention buffers ``attn.bias`` / ``attn.masked_bias`` are
            # not loaded into this model.
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales. Under pipeline
                # parallelism the owning layer may not exist on this rank,
                # in which case the scale is absent from params_dict.
                if is_pp_missing_parameter(scale_name, self):
                    continue
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

289

290
class GPTJForCausalLM(nn.Module, SupportsPP):
    """GPT-J causal language model.

    Wraps :class:`GPTJModel` (``transformer``) with an untied, biased
    ``lm_head``. Logits are computed by ``LogitsProcessor``, with the head's
    bias passed in explicitly.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # lm_head has its own weights (and bias); tying is not supported.
        assert not config.tie_word_embeddings, (
            "GPT-J with tie_word_embeddings=True is not supported")
        self.transformer = GPTJModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        # Pass a fully-qualified prefix (as done for ``transformer``) so
        # quantization configs can match the LM head by module name.
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the transformer."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the transformer; returns hidden states or PP intermediates."""
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights through ``AutoWeightsLoader``; returns loaded names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)