gpt_j.py 10.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
18
"""Inference-only GPT-J model compatible with HuggingFace weights."""
19
from typing import List, Optional
20
21
22
23
24

import torch
from torch import nn
from transformers import GPTJConfig

25
from vllm.attention import Attention, AttentionMetadata
26
from vllm.model_executor.layers.activation import get_act_fn
27
28
29
30
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
31
from vllm.model_executor.layers.logits_processor import LogitsProcessor
32
from vllm.model_executor.layers.rotary_embedding import get_rope
33
from vllm.model_executor.layers.sampler import Sampler
34
from vllm.model_executor.layers.vocab_parallel_embedding import (
35
    ParallelLMHead, VocabParallelEmbedding)
36
from vllm.model_executor.parallel_utils.parallel_state import (
37
    get_tensor_model_parallel_world_size)
38
from vllm.model_executor.sampling_metadata import SamplingMetadata
39
40
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
41
from vllm.sequence import SamplerOutput
42
43
44
45


class GPTJAttention(nn.Module):
    """Self-attention for one GPT-J layer with rotary position embeddings.

    Q, K and V come from a single fused, bias-free projection. A rotary
    embedding (``is_neox_style=False``) is applied to Q and K, and the
    attention output is projected back to ``hidden_size`` by a bias-free
    row-parallel layer. Heads are divided evenly across tensor-parallel
    ranks.
    """

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.head_size = self.hidden_size // self.total_num_heads

        # Fused Q/K/V input projection and the output projection; both are
        # bias-free.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        # Every tensor-parallel rank must receive the same number of heads.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        softmax_scale = self.head_size**-0.5
        # Only rotary position encoding is supported, and the rotated slice
        # of each head must have an even width.
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        base = getattr(config, "rope_theta", 10000)
        max_positions = getattr(config, "max_position_embeddings", 8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_positions,
            base=base,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads, self.head_size, softmax_scale)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and return the projected output.

        ``position_ids`` drive the rotary embedding; ``kv_cache`` and
        ``attn_metadata`` are handed unchanged to the attention backend.
        """
        fused, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into equal Q, K, V parts along the
        # last dimension.
        query, key, value = fused.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        output, _ = self.out_proj(context)
        return output


class GPTJMLP(nn.Module):
    """Feed-forward sublayer of a GPT-J block.

    Expands ``n_embd`` to ``intermediate_size`` with a column-parallel
    layer, applies the configured activation, then projects back to
    ``n_embd`` with a row-parallel layer.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        n_embd = config.n_embd
        self.fc_in = ColumnParallelLinear(
            n_embd,
            intermediate_size,
            linear_method=linear_method,
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            n_embd,
            linear_method=linear_method,
        )
        # The activation may depend on the quantization config, if the
        # linear method carries one; otherwise None is passed.
        self.act = get_act_fn(config.activation_function,
                              getattr(linear_method, "quant_config", None),
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc_in -> activation -> fc_out to ``hidden_states``."""
        expanded, _ = self.fc_in(hidden_states)
        projected, _ = self.fc_out(self.act(expanded))
        return projected


class GPTJBlock(nn.Module):
    """One GPT-J transformer layer with a parallel residual.

    A single LayerNorm output feeds both the attention and the MLP
    sublayers; their outputs are summed together with the un-normalized
    input.
    """

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        # The MLP width defaults to 4x the embedding size unless the config
        # sets n_inner explicitly.
        if config.n_inner is None:
            mlp_width = 4 * config.n_embd
        else:
            mlp_width = config.n_inner
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config, linear_method)
        self.mlp = GPTJMLP(mlp_width, config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and MLP on the normed input and add the residual."""
        normed = self.ln_1(hidden_states)
        attn_out = self.attn(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_out = self.mlp(normed)
        # Parallel residual: both sublayers read the same normed input.
        return attn_out + mlp_out + hidden_states


class GPTJModel(nn.Module):
    """GPT-J decoder trunk: token embedding, ``n_layer`` blocks, final norm.

    Produces final hidden states; it has no language-model head.
    """

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        blocks = [
            GPTJBlock(config, linear_method) for _ in range(config.n_layer)
        ]
        self.h = nn.ModuleList(blocks)
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every block, and apply the final norm.

        ``kv_caches`` is indexed by layer position, so it must hold at least
        one entry per block.
        """
        hidden_states = self.wte(input_ids)
        for layer_idx, block in enumerate(self.h):
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
            )
        return self.ln_f(hidden_states)


class GPTJForCausalLM(nn.Module):
    """GPT-J causal language model: the GPTJModel trunk plus a biased LM head.

    Generation is split across three methods: ``forward`` returns final
    hidden states, ``compute_logits`` projects them to vocabulary logits
    (including the LM-head bias), and ``sample`` turns logits into tokens.
    """

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        # The LM head keeps its own weight and bias instead of sharing the
        # input embedding, so tied-embedding configs are rejected.
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(config, linear_method)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the decoder trunk and return final hidden states (not logits)."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits using the LM head."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Select next tokens from ``logits`` via the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace GPT-J checkpoint tensors into this model.

        Args:
            model_name_or_path: Model id or local path passed to
                ``hf_model_weights_iterator``.
            cache_dir: Optional download/cache directory.
            load_format: Checkpoint format selector (``"auto"`` by default).
            revision: Optional model revision.

        Raises:
            KeyError: if a checkpoint tensor has no matching parameter (other
                than the skipped ``attn.bias``/``attn.masked_bias`` entries
                and GPTQ-only biases).
        """
        # Checkpoints store separate q/k/v projections; here they are fused
        # into one qkv_proj whose weight_loader takes a shard id.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            # These names have no matching parameter in this model
            # (presumably HF attention-mask buffers).
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            # Rename q/k/v_proj to the fused qkv_proj, stopping at the first
            # match: the rewritten "qkv_proj" itself contains "v_proj", so
            # scanning further would rewrite the name a second time.
            shard_id = None
            for param_name, weight_name, shard in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = shard
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, shard_id)