# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTJConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTJAttention(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
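        # GPT-J applies rotary embeddings to only the first `rotary_dim`
        # channels of each head, using the interleaved (non-NeoX) layout,
        # hence is_neox_style=False below.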
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_size,
            scaling,
            config.rotary_dim,
            base=rope_theta,
            max_position=max_position_embeddings,
            is_neox_style=False)
        self.warmup = False

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
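        # Split the fused QKV projection output into query, key, and value.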
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            linear_method=linear_method,
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config, linear_method)
        self.mlp = GPTJMLP(inner_dim, config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
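        # GPT-J uses a parallel residual: the attention and MLP branches both
        # read the same LayerNorm output, and their outputs are summed with
        # the residual.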
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


class GPTJModel(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.h = nn.ModuleList(
            [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim,
                                 eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.h)):
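            # cache_events (if provided) signal completion of asynchronous
            # KV-cache swaps; each layer's attention waits on its own event.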
            cache_event = None if cache_events is None else cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTJForCausalLM(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
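        # GPT-J does not tie the input embeddings to the LM head; the LM head
        # has its own weight and bias.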
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(config, linear_method)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
        )
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata, self.lm_head.bias)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
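        # Map the separate q_proj/k_proj/v_proj weights in the HF checkpoint
        # onto the fused qkv_proj parameter. The gate_up_proj entries do not
        # match any GPT-J weights (its MLP uses fc_in/fc_out).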
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
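            # Skip the static causal-mask buffers stored in the HF checkpoint
            # (attn.bias / attn.masked_bias); PagedAttention masks internally.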
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
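

# Usage sketch (illustrative, not part of the model code): this module is
# selected automatically by vLLM's engine based on the checkpoint's
# architecture, so it is normally exercised through the top-level API, e.g.:
#
#     from vllm import LLM, SamplingParams
#
#     llm = LLM(model="EleutherAI/gpt-j-6b")
#     outputs = llm.generate(["Hello,"], SamplingParams(max_tokens=16))
#
# The checkpoint name above is only an example of a GPT-J model.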