gpt_j.py 12.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
17
"""Inference-only GPT-J model compatible with HuggingFace weights."""
18
from typing import Iterable, List, Optional, Tuple, Union
19
20
21
22
23

import torch
from torch import nn
from transformers import GPTJConfig

24
from vllm.attention import Attention, AttentionMetadata
25
from vllm.compilation.decorators import support_torch_compile
26
from vllm.config import CacheConfig, VllmConfig
27
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
28
from vllm.model_executor.layers.activation import get_act_fn
29
30
31
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
32
from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
from vllm.model_executor.layers.quantization import QuantizationConfig
34
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
35
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
36
from vllm.model_executor.layers.vocab_parallel_embedding import (
37
    ParallelLMHead, VocabParallelEmbedding)
38
39
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import IntermediateTensors
42

43
44
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
45
46
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
47

48
49
50

class GPTJAttention(nn.Module):
    """Self-attention sub-layer of a GPT-J block.

    A single fused, bias-free projection produces Q, K and V, which are split
    into three equal chunks along the last dimension. Rotary position
    embeddings (``is_neox_style=False``, ``rotary_dim=config.rotary_dim``)
    are applied to Q and K before the attention kernel. The result goes
    through a bias-free row-parallel output projection.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        head_dim = self.head_size

        # Fused Q/K/V projection and output projection; neither has a bias.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            head_dim,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        # Every tensor-parallel rank must own the same number of heads.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        # Only rotary GPT-J configs are supported, and the rotary width
        # must be even.
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        # Fall back to theta=10000 and 8192 positions when the HF config
        # does not define them.
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=config.rotary_dim,
            max_position=getattr(config, "max_position_embeddings", 8192),
            base=getattr(config, "rope_theta", 10000),
            is_neox_style=False,
        )

        # Softmax scale is 1/sqrt(head_size).
        self.attn = Attention(self.num_heads,
                              head_dim,
                              head_dim**-0.5,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        ``position_ids`` feed the rotary embedding. ``kv_cache`` and
        ``attn_metadata`` are forwarded unchanged to the attention kernel.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.out_proj(context)
        return projected


class GPTJMLP(nn.Module):
    """Feed-forward sub-layer of a GPT-J block.

    Computes ``fc_out(act(fc_in(x)))``. ``fc_in`` is column-parallel and
    ``fc_out`` is row-parallel. The activation comes from
    ``config.activation_function``.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        model_dim = config.n_embd
        # Up-projection: model width -> intermediate width.
        self.fc_in = ColumnParallelLinear(
            model_dim,
            intermediate_size,
            quant_config=quant_config,
        )
        # Down-projection: intermediate width -> model width.
        self.fc_out = RowParallelLinear(
            intermediate_size,
            model_dim,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation and down-projection in order."""
        expanded, _ = self.fc_in(hidden_states)
        activated = self.act(expanded)
        contracted, _ = self.fc_out(activated)
        return contracted


class GPTJBlock(nn.Module):
    """One GPT-J transformer layer.

    A single LayerNorm (``ln_1``) feeds both the attention and the MLP
    branches. Their outputs are summed together with the un-normalized
    input.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # The MLP width defaults to 4x the model width when n_inner is unset.
        if config.n_inner is None:
            inner_dim = 4 * config.n_embd
        else:
            inner_dim = config.n_inner
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config, cache_config, quant_config)
        self.mlp = GPTJMLP(inner_dim, config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return ``attn(ln_1(x)) + mlp(ln_1(x)) + x``."""
        normed = self.ln_1(hidden_states)
        attn_out = self.attn(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_out = self.mlp(normed)
        # Same summation order as before: attention, MLP, then residual.
        return attn_out + mlp_out + hidden_states


178
@support_torch_compile
class GPTJModel(nn.Module):
    """GPT-J decoder stack: token embedding, transformer layers, final norm.

    Under pipeline parallelism, only this rank's layer range
    ``[start_layer, end_layer)`` (as returned by ``make_layers``) is run.
    Only the last rank applies ``ln_f``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.embed_dim = hf_config.n_embd
        self.wte = VocabParallelEmbedding(hf_config.vocab_size,
                                          self.embed_dim)

        # make_layers calls this factory with a `prefix` keyword argument,
        # so the parameter name must stay `prefix` even though it is unused.
        def build_block(prefix: str) -> GPTJBlock:
            return GPTJBlock(hf_config, cache_config, quant_config)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.n_layer,
            build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim,
                                 eps=hf_config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder.

        The first rank starts from ``inputs_embeds`` when given, otherwise
        from the embeddings of ``input_ids``. Later ranks start from
        ``intermediate_tensors["hidden_states"]``. ``kv_caches`` is indexed
        relative to ``start_layer``.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        for layer_idx in range(self.start_layer, self.end_layer):
            hidden_states = self.h[layer_idx](
                position_ids,
                hidden_states,
                kv_caches[layer_idx - self.start_layer],
                attn_metadata,
            )

        if pp_group.is_last_rank:
            return self.ln_f(hidden_states)
        # Hand the un-normalized activations to the next pipeline stage.
        return IntermediateTensors({"hidden_states": hidden_states})


237
class GPTJForCausalLM(nn.Module, SupportsPP):
    """GPT-J decoder with a biased LM head, sampler and HF weight loading."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # The LM head is a separate, biased matrix, so tied word embeddings
        # are not supported here.
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder.

        Returns final hidden states on the last pipeline rank and
        ``IntermediateTensors`` on the other ranks.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits, including the bias."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load (name, tensor) pairs from a HuggingFace GPT-J checkpoint.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are written into
        their shard of the fused ``qkv_proj`` parameter. Every other tensor is
        loaded by name, after optional kv-scale renaming. The following
        tensors are skipped:

        - HF attention-mask buffers (``attn.bias`` / ``attn.masked_bias``);
        - biases that have no parameter here (e.g. extra GPTQ biases);
        - parameters of layers that live on another pipeline rank.
        """
        # (param_name, shard_name, shard_id). GPT-J's MLP (fc_in/fc_out) is
        # not fused, so only the attention projections are stacked.
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            # Resolve the target parameter name exactly once. Breaking on the
            # first match matters: "v_proj" is a substring of "qkv_proj", so
            # re-scanning a renamed name would rename it again.
            shard_id = None
            for param_name, weight_name, stacked_shard_id in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break
            else:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused parameters always carry a shard-aware loader.
                param.weight_loader(param, loaded_weight, shard_id)