# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import List, Optional
20
21
22
23
24

import torch
from torch import nn
from transformers import GPTJConfig

25
from vllm.attention import Attention, AttentionMetadata
26
from vllm.distributed import get_tensor_model_parallel_world_size
27
from vllm.model_executor.layers.activation import get_act_fn
28
29
30
31
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
32
from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
from vllm.model_executor.layers.rotary_embedding import get_rope
34
from vllm.model_executor.layers.sampler import Sampler
35
from vllm.model_executor.layers.vocab_parallel_embedding import (
36
    ParallelLMHead, VocabParallelEmbedding)
37
from vllm.model_executor.sampling_metadata import SamplingMetadata
38
39
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
40
from vllm.sequence import SamplerOutput
41
42
43
44


class GPTJAttention(nn.Module):
    """Self-attention layer for GPT-J built from tensor-parallel linears.

    Holds a fused QKV projection, an output projection, a rotary embedding
    created with ``is_neox_style=False`` and an ``Attention`` backend that
    is configured with this rank's share of the attention heads.
    """

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: Model configuration. Reads ``num_attention_heads``,
                ``hidden_size`` and ``rotary_dim``, plus the optional
                ``rotary``, ``rope_theta`` and ``max_position_embeddings``.
            linear_method: Passed through unchanged to the parallel linear
                layers.

        Raises:
            AssertionError: If the head count is not divisible by the
                tensor-parallel world size, ``config.rotary`` is falsy, or
                ``config.rotary_dim`` is odd.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        # Each tensor-parallel rank handles an equal slice of the heads.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0, (
            "num_attention_heads must be divisible by the tensor-parallel "
            "world size")
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        assert getattr(config, "rotary", True), (
            "GPT-J without rotary embeddings is not supported")
        assert config.rotary_dim % 2 == 0, "rotary_dim must be even"
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            position_ids: Positions handed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Cache tensor forwarded to the attention backend.
            attn_metadata: Metadata forwarded to the attention backend.

        Returns:
            The output of ``out_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection output is split into three equal parts
        # along the last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):
    """Feed-forward sub-layer: ``fc_in`` -> activation -> ``fc_out``."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the two parallel linear layers and the activation.

        Args:
            intermediate_size: Output width of ``fc_in`` and input width of
                ``fc_out``.
            config: Model configuration. Reads ``n_embd`` and
                ``activation_function``.
            linear_method: Passed through to both linear layers; its
                ``quant_config`` attribute (``None`` when absent) is handed
                to ``get_act_fn``.
        """
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            linear_method=linear_method,
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc_in``, the activation and ``fc_out`` in sequence."""
        # Both linear layers return a 2-tuple; only the first item is used.
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):
    """One GPT-J layer: a shared LayerNorm feeding attention and MLP.

    The attention and MLP branches both read the same normalized input and
    their outputs are added to the residual in a single sum.
    """

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the LayerNorm, attention and MLP sub-modules.

        Args:
            config: Model configuration. Reads ``n_embd``, ``n_inner`` and
                ``layer_norm_epsilon``.
            linear_method: Passed through to the attention and MLP layers.
        """
        super().__init__()
        # Fall back to 4 * n_embd when the config leaves n_inner unset.
        inner_dim = (config.n_inner
                     if config.n_inner is not None else 4 * config.n_embd)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config, linear_method)
        self.mlp = GPTJMLP(inner_dim, config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layer on ``hidden_states``.

        Args:
            position_ids: Forwarded to the attention sub-layer.
            hidden_states: Layer input; also used as the residual.
            kv_cache: Forwarded to the attention sub-layer.
            attn_metadata: Forwarded to the attention sub-layer.

        Returns:
            ``attention(ln_1(x)) + mlp(ln_1(x)) + x``.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # The MLP consumes the normalized input, not the attention output.
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


class GPTJModel(nn.Module):
    """GPT-J backbone: token embedding, layer stack and final LayerNorm."""

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the embedding, ``config.n_layer`` blocks and ``ln_f``.

        Args:
            config: Model configuration. Reads ``n_embd``, ``vocab_size``,
                ``n_layer`` and ``layer_norm_epsilon``.
            linear_method: Passed through to every ``GPTJBlock``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.h = nn.ModuleList(
            [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every layer.

        Args:
            input_ids: Token ids passed to the embedding.
            position_ids: Forwarded to each layer.
            kv_caches: One cache per layer, indexed by layer position.
            attn_metadata: Forwarded to each layer.

        Returns:
            The output of ``ln_f`` applied to the last layer's output.
        """
        hidden_states = self.wte(input_ids)
        # kv_caches is indexed (not zipped) so a too-short list still
        # raises IndexError instead of silently skipping layers.
        for layer_idx, layer in enumerate(self.h):
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTJForCausalLM(nn.Module):
    """GPT-J backbone plus LM head, logits processor and sampler."""

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            config: Model configuration. Reads ``tie_word_embeddings``,
                ``vocab_size`` and ``n_embd``.
            linear_method: Stored on the instance and passed through to
                ``GPTJModel``.

        Raises:
            AssertionError: If ``config.tie_word_embeddings`` is truthy.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        assert not config.tie_word_embeddings, (
            "tied word embeddings are not supported")
        self.transformer = GPTJModel(config, linear_method)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the backbone's hidden states for ``input_ids``."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Run the logits processor with the LM head's weight and bias."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Return the sampler's result for ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Copy checkpoint weights into this module's parameters.

        All four arguments are forwarded unchanged to
        ``hf_model_weights_iterator``, which yields ``(name, tensor)``
        pairs.

        Raises:
            KeyError: If a checkpoint entry that is not skipped has no
                matching parameter.
        """
        # Checkpoint q/k/v projections are loaded as shards of the fused
        # ``qkv_proj`` parameter. The MLP in this file only has ``fc_in``
        # and ``fc_out``, so there are no gate/up entries to stack.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            # NOTE(review): presumably the HF causal-mask buffers, which
            # have no counterpart here -- confirm against the checkpoint.
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Keep ``name`` untouched: "v_proj" is a substring of
                # "qkv_proj", so rewriting in place would let a later
                # mapping entry match the already-rewritten name.
                stacked_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (stacked_name.endswith(".bias")
                        and stacked_name not in params_dict):
                    break
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)