gpt_j.py 10.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
18
"""Inference-only GPT-J model compatible with HuggingFace weights."""
19
from typing import List, Optional, Tuple
20
21
22
23
24
25
26

import torch
from torch import nn
from transformers import GPTJConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
27
from vllm.model_executor.layers.attention import Attention
28
29
30
31
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
Woosuk Kwon's avatar
Woosuk Kwon committed
32
from vllm.model_executor.layers.rotary_embedding import get_rope
33
from vllm.model_executor.layers.sampler import Sampler
34
35
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
36
from vllm.model_executor.parallel_utils.parallel_state import (
37
    get_tensor_model_parallel_world_size)
38
from vllm.model_executor.sampling_metadata import SamplingMetadata
39
40
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
41
from vllm.sequence import SamplerOutput
42
43
44
45
46
47

# Cache pair handed to each layer; the attention layer unpacks it as
# (k_cache, v_cache).
KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTJAttention(nn.Module):
    """Self-attention layer of a GPT-J block.

    Q, K and V come from one fused, bias-free projection. Rotary position
    embeddings (``is_neox_style=False``) are applied to Q and K before the
    attention op, and the attention output goes through a bias-free output
    projection.
    """

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Create the projections, the rotary embedding and the attention op.

        Args:
            config: Model configuration. Reads ``num_attention_heads``,
                ``hidden_size`` and ``rotary_dim``, plus the optional
                ``rotary``, ``rope_theta`` and ``max_position_embeddings``.
            linear_method: Optional method object forwarded to the parallel
                linear layers.
        """
        super().__init__()
        embed_dim = config.hidden_size
        head_count = config.num_attention_heads
        self.total_num_heads = head_count
        self.hidden_size = embed_dim
        self.head_size = embed_dim // head_count

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_size,
            self.total_num_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=False,
            linear_method=linear_method,
        )

        # Heads are divided evenly across the tensor-parallel ranks.
        world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        # Only the rotary variant is supported, with an even rotary_dim.
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=getattr(config, "max_position_embeddings", 8192),
            base=getattr(config, "rope_theta", 10000),
            is_neox_style=False,
        )

        # Attention scores are scaled by 1 / sqrt(head_size).
        self.attn = Attention(self.num_heads, self.head_size,
                              self.head_size**-0.5)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            position_ids: Positions passed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: ``(k_cache, v_cache)`` pair for this layer.
            input_metadata: Metadata forwarded to the attention op.

        Returns:
            The output of the output projection.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is split into three equal parts on the
        # last dimension.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        key_cache, value_cache = kv_cache
        context = self.attn(query, key, value, key_cache, value_cache,
                            input_metadata)
        projected, _ = self.out_proj(context)
        return projected


class GPTJMLP(nn.Module):
    """Feed-forward part of a GPT-J block: fc_in -> activation -> fc_out."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Create the two projections and the activation.

        Args:
            intermediate_size: Width of the hidden projection.
            config: Model configuration. Reads ``n_embd`` and
                ``activation_function``.
            linear_method: Optional method object forwarded to the parallel
                linear layers; its ``quant_config`` attribute, if any, is
                passed to the activation factory.
        """
        super().__init__()
        embed_dim = config.n_embd
        self.fc_in = ColumnParallelLinear(
            embed_dim,
            intermediate_size,
            linear_method=linear_method,
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            embed_dim,
            linear_method=linear_method,
        )
        self.act = get_act_fn(
            config.activation_function,
            getattr(linear_method, "quant_config", None),
            intermediate_size,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc_in``, the activation and ``fc_out`` in sequence."""
        expanded, _ = self.fc_in(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.fc_out(activated)
        return contracted


class GPTJBlock(nn.Module):
    """One GPT-J layer.

    Attention and MLP both read the same layer-normed input; their outputs
    are summed together with the residual.
    """

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Create the layer norm, attention and MLP sub-modules.

        Args:
            config: Model configuration. Reads ``n_embd``, ``n_inner`` and
                ``layer_norm_epsilon``.
            linear_method: Optional method object forwarded to the
                attention and MLP sub-modules.
        """
        super().__init__()
        # The MLP width falls back to 4 * n_embd when n_inner is unset.
        if config.n_inner is None:
            inner_dim = 4 * config.n_embd
        else:
            inner_dim = config.n_inner
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config, linear_method)
        self.mlp = GPTJMLP(inner_dim, config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return ``attn(ln_1(x)) + mlp(ln_1(x)) + x`` for input ``x``.

        Args:
            position_ids: Positions forwarded to the attention layer.
            hidden_states: Input to the block.
            kv_cache: Cache pair forwarded to the attention layer.
            input_metadata: Metadata forwarded to the attention layer.
        """
        skip = hidden_states
        normed = self.ln_1(hidden_states)
        attn_out = self.attn(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        # The MLP consumes the normed input, not the attention output.
        mlp_out = self.mlp(normed)
        return attn_out + mlp_out + skip


class GPTJModel(nn.Module):
    """GPT-J backbone: token embedding, a stack of blocks, final norm."""

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Create the embedding, ``config.n_layer`` blocks and ``ln_f``.

        Args:
            config: Model configuration. Reads ``n_embd``, ``vocab_size``,
                ``n_layer`` and ``layer_norm_epsilon``.
            linear_method: Optional method object forwarded to every block.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        blocks = [
            GPTJBlock(config, linear_method) for _ in range(config.n_layer)
        ]
        self.h = nn.ModuleList(blocks)
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every block.

        Args:
            input_ids: Token ids looked up in ``wte``.
            position_ids: Positions forwarded to each block.
            kv_caches: One cache pair per block, indexed by block position.
            input_metadata: Metadata forwarded to each block.

        Returns:
            The output of the final layer norm.
        """
        states = self.wte(input_ids)
        for index, block in enumerate(self.h):
            states = block(
                position_ids,
                states,
                kv_caches[index],
                input_metadata,
            )
        return self.ln_f(states)


class GPTJForCausalLM(nn.Module):
    """GPT-J backbone plus a biased LM head and a sampler."""

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Create the backbone, the LM head and the sampler.

        Args:
            config: Model configuration. Reads ``tie_word_embeddings``,
                ``vocab_size`` and ``n_embd``.
            linear_method: Optional method object forwarded to the
                backbone.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        # Tied input/output embeddings are not supported by this model.
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(config, linear_method)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
        )
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return the backbone's hidden states.

        The LM head is not applied here; it is used by :meth:`sample`.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler with the LM head's weight and bias.

        Args:
            hidden_states: Hidden states as returned by :meth:`forward`.
            sampling_metadata: Metadata forwarded to the sampler.

        Returns:
            Whatever the sampler returns.
        """
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata, self.lm_head.bias)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint weights into this module's parameters.

        Checkpoint entries named ``q_proj`` / ``k_proj`` / ``v_proj`` are
        loaded as shards of the fused ``qkv_proj`` parameter. Entries whose
        name contains ``attn.bias`` or ``attn.masked_bias`` are ignored, as
        are ``.bias`` entries with no matching parameter.

        Args:
            model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            load_format: Forwarded to ``hf_model_weights_iterator``.
            revision: Forwarded to ``hf_model_weights_iterator``.

        Raises:
            KeyError: If a non-skipped checkpoint entry has no matching
                parameter.
        """
        # Only the attention projections are fused in this model; the MLP
        # uses separate fc_in / fc_out layers and needs no mapping.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            # Resolve the shard mapping exactly once. Stopping at the first
            # match matters: "v_proj" is a substring of "qkv_proj", so
            # scanning on with the rewritten name would rewrite it again.
            shard_id = None
            for param_name, weight_name, mapped_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = mapped_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            if shard_id is not None:
                # Fused parameters must provide their own shard-aware loader.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)