# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""

from typing import Iterable, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import GPTJConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class GPTJAttention(nn.Module):
    """Self-attention sub-layer of one GPT-J decoder block.

    Q, K and V come from a single fused tensor-parallel projection.
    Rotary position embeddings (non-NeoX layout) are applied to q and k
    before vLLM's ``Attention`` layer. A row-parallel ``out_proj`` then maps
    the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: HuggingFace GPT-J configuration.
            cache_config: Cache configuration forwarded to ``Attention``.
            quant_config: Optional quantization config for the projections
                and the attention layer.
            prefix: Fully-qualified module name of this sub-layer. Child
                module names are derived from it so that quantization
                configs can address each projection individually.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        # Each tensor-parallel rank owns an equal share of the heads.
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        # Only the rotary variant of GPT-J is supported, and the rotary
        # dimension must be even.
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            # GPT-J uses the non-NeoX rotary layout.
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run rotary self-attention over ``hidden_states``.

        Args:
            position_ids: Token positions consumed by the rotary embedding.
            hidden_states: Input activations (already layer-normalized by
                the caller).

        Returns:
            The attention output after ``out_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is split back into equally sized q, k, v.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):
    """GPT-J feed-forward sub-layer: ``fc_in`` -> activation -> ``fc_out``."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the two tensor-parallel projections and the activation.

        Args:
            intermediate_size: Width of the hidden MLP projection.
            config: HuggingFace GPT-J configuration; ``n_embd`` and
                ``activation_function`` are read from it.
            quant_config: Optional quantization config for both projections.
            prefix: Fully-qualified module name of this sub-layer, used to
                name ``fc_in``/``fc_out`` for per-layer quantization lookup.
        """
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc_in",
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc_out",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward projection to ``hidden_states``."""
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):
    """One GPT-J decoder layer with a parallel residual.

    Attention and MLP both consume the same ``ln_1`` output. Their results
    are added to the un-normalized residual in a single step.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the layer norm, attention and MLP sub-layers.

        Args:
            config: HuggingFace GPT-J configuration.
            cache_config: Cache configuration forwarded to attention.
            quant_config: Optional quantization config for all sub-layers.
            prefix: Fully-qualified module name of this decoder layer.
        """
        super().__init__()
        # When the config leaves n_inner unset, the MLP is 4x the hidden size.
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.mlp = GPTJMLP(inner_dim,
                           config,
                           quant_config,
                           prefix=f"{prefix}.mlp")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply one decoder layer.

        Args:
            position_ids: Token positions for the rotary embedding.
            hidden_states: Output of the previous layer (or embeddings).

        Returns:
            ``attn(ln_1(x)) + mlp(ln_1(x)) + x``.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        # Parallel residual: the MLP reads the normalized input, not the
        # attention output.
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


@support_torch_compile
class GPTJModel(nn.Module):
    """GPT-J decoder stack with pipeline-parallel support.

    The stack consists of token embedding, decoder layers and a final
    LayerNorm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build embeddings, this stage's decoder layers and ``ln_f``.

        Args:
            vllm_config: Engine configuration; the HF config, cache config
                and quantization config are read from it.
            prefix: Fully-qualified module name of this model.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        # make_layers yields the [start_layer, end_layer) range this stage
        # executes, together with the layer container.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.n_layer,
            lambda prefix: GPTJBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's part of the decoder.

        Args:
            input_ids: Token ids (used on the first stage when
                ``inputs_embeds`` is not given).
            position_ids: Token positions for the rotary embeddings.
            intermediate_tensors: Activations from the previous stage;
                required on every stage except the first.
            inputs_embeds: Optional precomputed input embeddings.

        Returns:
            ``IntermediateTensors`` for non-last stages, otherwise the
            final ``ln_f``-normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.h[self.start_layer:self.end_layer]:
            hidden_states = layer(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HuggingFace GPT-J checkpoint tensors into this model.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are loaded as
        shards of the fused ``qkv_proj`` parameter. KV-cache scales reported
        by the quantization config are loaded directly. Every other tensor
        goes to the parameter of the same (possibly remapped) name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The names of the parameters that were loaded.
        """
        # (param_name, checkpoint_shard_name, shard_id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Checkpoint entries with no parameter counterpart here
            # (presumably HF attention-mask buffers).
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Map a separate q/k/v tensor onto its qkv_proj shard. Stop at
            # the first match: the rewritten name contains "v_proj" (inside
            # "qkv_proj"), so scanning further would rewrite it again.
            shard_id = None
            for param_name, weight_name, stacked_shard_id in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break
            else:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Parameters of layers owned by another pipeline stage.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

class GPTJForCausalLM(nn.Module, SupportsPP):
    """GPT-J causal language model: ``GPTJModel`` plus a biased LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the transformer, LM head and logits processor.

        Args:
            vllm_config: Engine configuration; the HF config and
                quantization config are read from it.
            prefix: Fully-qualified module name of this model.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # The LM head carries its own weights and bias; tied embeddings are
        # not supported by this implementation.
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the underlying transformer."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the transformer.

        Returns:
            Hidden states on the last pipeline stage, otherwise the
            ``IntermediateTensors`` to send to the next stage. Logits are
            produced separately by ``compute_logits``.
        """
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits.

        The LM head's bias is passed to the logits processor explicitly.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Returns:
            The names of the parameters that were loaded.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)