# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTJConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput


class GPTJAttention(nn.Module):
    """Self-attention sub-layer of a GPT-J block.

    Q, K and V come from one fused, tensor-parallel ``QKVParallelLinear``
    projection. Rotary position embeddings are applied to Q and K, vLLM's
    ``Attention`` layer computes attention against the KV cache, and a
    row-parallel ``out_proj`` maps the result back to ``hidden_size``.
    Neither projection has a bias.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections, rotary embedding and attention layer.

        Args:
            config: GPT-J config. Reads ``num_attention_heads``,
                ``hidden_size`` and ``rotary_dim``, plus the optional
                ``rotary``, ``rope_theta`` (default 10000) and
                ``max_position_embeddings`` (default 8192).
            cache_config: Forwarded to the ``Attention`` layer.
            quant_config: Forwarded to both linear layers and to
                ``Attention``.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        # Heads are divided evenly across tensor-parallel ranks; this rank's
        # Attention layer only sees num_heads of them.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be divisible "
            f"by the tensor-parallel world size ({tp_world_size})")
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        # This implementation always applies rotary embeddings, so a config
        # that explicitly disables them is rejected.
        assert getattr(config, "rotary", True), (
            "GPT-J without rotary embeddings is not supported")
        assert config.rotary_dim % 2 == 0, (
            f"rotary_dim must be even, got {config.rotary_dim}")
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # rotary_dim may be smaller than head_size; presumably only that
        # many channels per head are rotated (see get_rope). The non-NeoX
        # layout is what GPT-J checkpoints expect (NOTE(review): confirm the
        # exact pairing convention in get_rope).
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            position_ids: Token positions fed to the rotary embedding.
            hidden_states: Sub-layer input (presumably
                ``(num_tokens, hidden_size)`` — TODO confirm).
            kv_cache: This layer's KV cache, passed to ``Attention``.
            attn_metadata: Batch metadata, passed to ``Attention``.

        Returns:
            The attention output after ``out_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into equal q, k, v along the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):
    """Feed-forward sub-layer of a GPT-J block.

    ``fc_in`` (column-parallel) expands ``n_embd`` to ``intermediate_size``.
    The configured activation is then applied, and ``fc_out``
    (row-parallel) projects back to ``n_embd``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the two projections and the activation.

        Args:
            intermediate_size: Width of the MLP's inner layer.
            config: GPT-J config. Reads ``n_embd`` and
                ``activation_function``.
            quant_config: Forwarded to both linear layers and to
                ``get_act_fn``.
        """
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=quant_config,
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc_in`` -> activation -> ``fc_out``."""
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):
    """One GPT-J decoder layer with parallel attention and MLP.

    Both sub-layers read the same ``ln_1`` output, and their results are
    added to the un-normalised input:
    ``out = attn(ln_1(x)) + mlp(ln_1(x)) + x``.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the layer norm, attention and MLP sub-layers.

        Args:
            config: GPT-J config. Reads ``n_embd``, ``n_inner`` and
                ``layer_norm_epsilon``, and is passed on to the sub-layers.
            cache_config: Forwarded to ``GPTJAttention``.
            quant_config: Forwarded to ``GPTJAttention`` and ``GPTJMLP``.
        """
        super().__init__()
        # n_inner is optional; when unset, the MLP width is 4 * n_embd.
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config, cache_config, quant_config)
        self.mlp = GPTJMLP(inner_dim, config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply one decoder layer.

        Args:
            position_ids: Token positions, forwarded to attention.
            hidden_states: Layer input.
            kv_cache: This layer's KV cache, forwarded to attention.
            attn_metadata: Batch metadata, forwarded to attention.

        Returns:
            ``attn_output + mlp_output + hidden_states`` (the input).
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # The MLP consumes the same normalised input as attention, not the
        # attention output (GPT-J's parallel residual formulation).
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


class GPTJModel(nn.Module):
    """GPT-J decoder stack without an LM head.

    The stack is a vocab-parallel token embedding (``wte``), ``n_layer``
    ``GPTJBlock`` layers (``h``) and a final LayerNorm (``ln_f``).
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the embedding, decoder layers and final norm.

        Args:
            config: GPT-J config. Reads ``vocab_size``, ``n_embd``,
                ``n_layer`` and ``layer_norm_epsilon``.
            cache_config: Forwarded to every ``GPTJBlock``.
            quant_config: Forwarded to every ``GPTJBlock``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.h = nn.ModuleList([
            GPTJBlock(config, cache_config, quant_config)
            for _ in range(config.n_layer)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            position_ids: Token positions, forwarded to every layer.
            kv_caches: One KV cache per layer; ``kv_caches[i]`` belongs to
                ``self.h[i]``.
            attn_metadata: Batch metadata, forwarded to every layer.

        Returns:
            Hidden states after ``ln_f``.
        """
        hidden_states = self.wte(input_ids)
        for i, layer in enumerate(self.h):
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTJForCausalLM(nn.Module):
    """GPT-J causal language model for vLLM.

    Wraps ``GPTJModel`` with an untied, biased LM head, a logits processor
    and a sampler. ``load_weights`` maps HuggingFace checkpoint tensors
    onto this model's (possibly fused) parameters.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the decoder stack, LM head, logits processor and sampler.

        Args:
            config: GPT-J config. Must have ``tie_word_embeddings`` false,
                because the LM head is a separate parameter with its own
                bias.
            cache_config: Forwarded to ``GPTJModel``.
            quant_config: Stored on the module and forwarded to
                ``GPTJModel``.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        assert not config.tie_word_embeddings, (
            "GPT-J with tied word embeddings is not supported")
        self.transformer = GPTJModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return final hidden states; logits come from ``compute_logits``."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits.

        The LM head's weight and bias are passed to the logits processor.
        """
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy HuggingFace GPT-J checkpoint tensors into this model.

        How each tensor is handled:

        * ``q_proj``, ``k_proj`` and ``v_proj`` tensors are loaded as
          shards of the fused ``qkv_proj`` parameter.
        * Every other tensor is matched by name.
        * Names containing ``attn.bias`` or ``attn.masked_bias`` are skipped.
        * ``.bias`` tensors with no matching parameter are skipped; these
          are the extra biases shipped by GPTQ checkpoints.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            KeyError: If a non-bias tensor has no matching parameter.
        """
        # (param_name, shard_name, shard_id). GPT-J's MLP uses separate
        # fc_in/fc_out layers, so only the attention projections are fused.
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # GPTJAttention defines no parameters under these names.
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Keep `name` intact. If it were rebound and the scan
                # continued, later entries would match the rewritten string,
                # because "qkv_proj" itself contains "v_proj".
                stacked_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. `break` also
                # bypasses the for/else fallback below.
                if (stacked_name.endswith(".bias")
                        and stacked_name not in params_dict):
                    break
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a fused shard: load the tensor by its own name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)