"vllm/vscode:/vscode.git/clone" did not exist on "ffadd03540db4f54be7c010b6b4549a13dcad1a4"
gpt_j.py 13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only GPT-J model compatible with HuggingFace weights."""
21
from collections.abc import Iterable
22
from itertools import islice
23
from typing import Optional, Union
24
25
26
27
28

import torch
from torch import nn
from transformers import GPTJConfig

29
from vllm.attention import Attention
30
from vllm.compilation.decorators import support_torch_compile
31
from vllm.config import CacheConfig, VllmConfig
32
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
33
from vllm.model_executor.layers.activation import get_act_fn
34
35
36
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
37
from vllm.model_executor.layers.logits_processor import LogitsProcessor
38
from vllm.model_executor.layers.quantization import QuantizationConfig
39
from vllm.model_executor.layers.rotary_embedding import get_rope
40
from vllm.model_executor.layers.vocab_parallel_embedding import (
41
    ParallelLMHead, VocabParallelEmbedding)
42
43
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
44
from vllm.model_executor.sampling_metadata import SamplingMetadata
45
from vllm.sequence import IntermediateTensors
46

47
from .interfaces import SupportsPP
48
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
49
50
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
51

52
53
54

class GPTJAttention(nn.Module):
    """Multi-head self-attention for GPT-J with rotary position embeddings.

    Q, K and V come from a single fused tensor-parallel projection, and the
    heads are divided evenly across tensor-parallel ranks. Rotary embeddings
    are applied over ``config.rotary_dim`` channels of each head using the
    GPT-J (``is_neox_style=False``) layout.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        # GPT-J attention projections have no bias. The module prefix is
        # passed down so quantization configs that select layers by name
        # (e.g. ignore lists) can match these layers.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Each tensor-parallel rank owns an equal share of the heads.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        # Only rotary GPT-J checkpoints are supported, and the rotary
        # dimension must be even so channels can be rotated in pairs.
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply rotary self-attention and the output projection.

        Args:
            position_ids: Token positions consumed by the rotary embedding.
            hidden_states: Layer input (already layer-normalized by the
                caller).

        Returns:
            The attention output after ``out_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # qkv_proj is built without a separate KV-head count, so the local
        # Q, K and V slices presumably have equal width and an even
        # three-way split of the last dim recovers them.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):
    """GPT-J feed-forward network: ``fc_out(act(fc_in(x)))``.

    ``fc_in`` is column-parallel and ``fc_out`` is row-parallel, so the
    intermediate activation stays sharded across tensor-parallel ranks.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the two projections and the activation.

        Args:
            intermediate_size: Width of the hidden MLP layer.
            config: HF GPT-J config (``n_embd``, ``activation_function``).
            quant_config: Optional quantization config for the linears.
            prefix: Module name prefix, forwarded to the linear layers so
                name-based quantization settings can match them. Defaults
                to ``""`` for backward compatibility.
        """
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "fc_in"),
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "fc_out"),
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up, apply the activation, and project back down."""
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):
    """A single GPT-J decoder layer with a parallel residual.

    Attention and the MLP both consume the same ``ln_1``-normalized input,
    and their outputs are added to the residual:
    ``x + attn(ln_1(x)) + mlp(ln_1(x))``.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # When the config leaves n_inner unset (None), use 4x the model width.
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.mlp = GPTJMLP(inner_dim,
                           config,
                           quant_config,
                           prefix=f"{prefix}.mlp")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and the MLP in parallel and add the residual."""
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        # The MLP reads the same normalized input as attention (parallel
        # residual), not the attention output.
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


182
@support_torch_compile
class GPTJModel(nn.Module):
    """GPT-J decoder stack: token embedding, ``n_layer`` blocks, final norm.

    With pipeline parallelism, ``forward`` runs only the layers in
    ``[start_layer, end_layer)``. The first rank embeds tokens, and every
    non-last rank returns its hidden states as ``IntermediateTensors`` for
    the next rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        # Kept for load_weights, which asks it for KV-cache scale names.
        self.quant_config = quant_config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.start_layer, self.end_layer, self.h = make_layers(
            config.n_layer,
            lambda prefix: GPTJBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline rank's slice of the decoder.

        Args:
            input_ids: Token ids. They are used only on the first rank, and
                only when ``inputs_embeds`` is not given.
            position_ids: Token positions for the rotary embeddings.
            intermediate_tensors: Hidden states from the previous pipeline
                rank. Required on every rank except the first.
            inputs_embeds: Optional precomputed embeddings used instead of
                ``input_ids`` on the first rank.

        Returns:
            On the last rank, the ``ln_f``-normalized hidden states.
            Otherwise, an ``IntermediateTensors`` holding the un-normalized
            hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Later pipeline stages take their input from the previous stage.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = layer(position_ids, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this (possibly sharded) model.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are routed into
        the fused ``qkv_proj`` parameter. KV-cache quantization scales are
        loaded when the quant config recognizes them. Everything else is
        loaded by its (possibly remapped) name. Parameters that belong to
        layers owned by another pipeline rank are skipped.

        Args:
            weights: ``(name, tensor)`` pairs. Names are looked up in
                ``self.named_parameters()``, so they must be relative to
                this module.

        Returns:
            The names of the parameters that were loaded.
        """
        # (fused param name, checkpoint shard name, shard id). GPT-J's MLP
        # uses fc_in/fc_out, so only the attention projections are stacked.
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Checkpoint entries with no matching parameter here (presumably
            # the HF attention-mask buffers).
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # KV-cache quantization scale. Its layer may live on another
                # pipeline rank, so apply the same guard as the branches
                # below instead of raising KeyError on params_dict.
                if is_pp_missing_parameter(scale_name, self):
                    continue
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Non-scalar scales are reduced to their first element.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load directly, remapping legacy
                # KV-scale names first.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

291

292
class GPTJForCausalLM(nn.Module, SupportsPP):
    """GPT-J causal language model: ``GPTJModel`` plus a biased LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # The LM head is a separate projection with its own bias. Tied
        # input/output embeddings are not supported here.
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        # Expose the backbone's factory so pipeline-parallel scheduling can
        # allocate inter-stage buffers for this model.
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone's ``wte``."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder backbone.

        Returns the final hidden states on the last pipeline rank, otherwise
        the ``IntermediateTensors`` produced by ``GPTJModel``.
        """
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits, including the LM-head bias."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Returns the set of loaded parameter names.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)