# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTJConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput


class GPTJAttention(nn.Module):
    """Self-attention layer of a GPT-J block.

    Projects the input into query/key/value with a single fused
    ``QKVParallelLinear``, applies rotary position embedding to the
    query and key, runs ``Attention`` and projects the result back with
    ``out_proj``.
    """

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the projections, rotary embedding and attention op.

        Args:
            config: Model configuration. Reads ``num_attention_heads``,
                ``hidden_size`` and ``rotary_dim``; ``rotary``,
                ``rope_theta`` and ``max_position_embeddings`` are
                optional and fall back to ``True``, ``10000`` and
                ``8192`` respectively.
            linear_method: Forwarded unchanged to both linear layers.

        Raises:
            AssertionError: If the head count is not divisible by the
                tensor-model-parallel world size, if ``config.rotary``
                is falsy, or if ``config.rotary_dim`` is odd.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        # Neither projection carries a bias term.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        # Number of heads handled by this rank: the total is split evenly
        # across the tensor-model-parallel world size.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        # Only the rotary variant is supported, and the rotary dimension
        # must be even.
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states``.

        Args:
            position_ids: Positions passed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Cache tensor handed to the attention op.
            attn_metadata: Metadata handed to the attention op.

        Returns:
            The output of ``out_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is split into three equal parts along the
        # last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):
    """Feed-forward sub-layer: ``fc_in`` -> activation -> ``fc_out``."""

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the two linear layers and the activation.

        Args:
            intermediate_size: Output width of ``fc_in`` and input width
                of ``fc_out``.
            config: Model configuration. Reads ``n_embd`` (used as the
                hidden size) and ``activation_function``.
            linear_method: Forwarded unchanged to both linear layers; its
                ``quant_config`` attribute, if present, is passed to
                ``get_act_fn``.
        """
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            linear_method=linear_method,
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            linear_method=linear_method,
        )
        # ``linear_method`` may be None, in which case quant_config is
        # None as well.
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``fc_out(act(fc_in(hidden_states)))``."""
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):
    """One GPT-J transformer layer.

    A single layer norm (``ln_1``) feeds both the attention and the MLP
    sub-layers; their outputs are summed together with the residual.
    """

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the layer norm, attention and MLP sub-layers.

        Args:
            config: Model configuration. Reads ``n_embd``, ``n_inner``
                and ``layer_norm_epsilon``.
            linear_method: Forwarded to ``GPTJAttention`` and ``GPTJMLP``.
        """
        super().__init__()
        # The MLP width defaults to four times the embedding size when
        # the config does not set ``n_inner``.
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config, linear_method)
        self.mlp = GPTJMLP(inner_dim, config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layer on ``hidden_states``.

        Args:
            position_ids: Forwarded to the attention sub-layer.
            hidden_states: Layer input; also used as the residual.
            kv_cache: Forwarded to the attention sub-layer.
            attn_metadata: Forwarded to the attention sub-layer.

        Returns:
            ``attn_output + mlp_output + residual``.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # The MLP consumes the normalised input, not the attention output.
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


class GPTJModel(nn.Module):
    """GPT-J backbone: token embedding, a stack of blocks, final norm."""

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the embedding, ``config.n_layer`` blocks and ``ln_f``.

        Args:
            config: Model configuration. Reads ``n_embd``, ``vocab_size``,
                ``n_layer`` and ``layer_norm_epsilon``.
            linear_method: Forwarded to every ``GPTJBlock``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.h = nn.ModuleList(
            [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every block.

        Args:
            input_ids: Token ids looked up in ``wte``.
            position_ids: Forwarded to every block.
            kv_caches: One cache per block, indexed by block position.
            attn_metadata: Forwarded to every block.

        Returns:
            The output of the last block after ``ln_f``.

        Raises:
            IndexError: If ``kv_caches`` has fewer entries than there are
                blocks.
        """
        hidden_states = self.wte(input_ids)
        # Index kv_caches explicitly (rather than zip) so a too-short
        # list still fails loudly instead of silently skipping layers.
        for i, layer in enumerate(self.h):
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTJForCausalLM(nn.Module):
    """GPT-J backbone with an LM head, logits processor and sampler."""

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            config: Model configuration. Reads ``tie_word_embeddings``,
                ``vocab_size`` and ``n_embd``.
            linear_method: Stored on the instance and forwarded to
                ``GPTJModel``.

        Raises:
            AssertionError: If ``config.tie_word_embeddings`` is truthy.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(config, linear_method)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the backbone's hidden states (no LM head applied)."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Return what the logits processor produces for ``hidden_states``.

        The LM head's weight and bias are passed to the processor along
        with ``sampling_metadata``.
        """
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs. Names containing
                ``"attn.bias"`` or ``"attn.masked_bias"`` are ignored.

        Raises:
            KeyError: If a name (after any stacked-parameter renaming)
                does not match a parameter of this module, unless the
                name ends in ``".bias"``, in which case it is skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Checkpoint shard name -> fused parameter name.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE: this ``continue`` resumes the mapping loop with
                # the already-renamed ``name``; the bias is ultimately
                # dropped by the matching check in the ``else`` branch.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no stacked mapping loaded the weight.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)