# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import GPTJConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class GPTJAttention(nn.Module):
    """Self-attention sub-layer of one GPT-J block.

    Uses a fused, bias-free q/k/v projection, rotary position embeddings in
    the non-NeoX layout (``is_neox_style=False``) over ``config.rotary_dim``
    dimensions per head, and a bias-free output projection. Attention heads
    are split evenly across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections, rotary embedding and attention kernel.

        Args:
            config: HuggingFace GPT-J config.
            cache_config: KV-cache settings forwarded to ``Attention``.
            quant_config: Optional quantization config applied to the
                linear layers and the attention kernel.
            prefix: Fully-qualified module name of this layer; sub-module
                names are derived from it.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        # Give every linear layer its own fully-qualified name so that
        # quantization configs which select or skip modules by name (for
        # example an ignore list) can recognise these projections.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Heads must divide evenly across tensor-parallel ranks.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        # Only rotary GPT-J variants are supported, and the rotary span must
        # be even because rotary embeddings rotate channel pairs.
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply rotary self-attention and the output projection.

        Args:
            position_ids: Token positions used by the rotary embedding.
            hidden_states: Input activations (already layer-normed by the
                caller).
            kv_cache: This layer's KV-cache tensor.
            attn_metadata: Backend-specific attention metadata.

        Returns:
            The projected attention output.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # No ``total_num_kv_heads`` is passed to QKVParallelLinear above, so
        # K/V presumably use the same head count as Q and the fused output
        # splits into three equal parts.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):
    """Feed-forward sub-layer of a GPT-J block: fc_in -> activation -> fc_out.

    ``fc_in`` is column-parallel and ``fc_out`` is row-parallel across
    tensor-parallel ranks.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the two projections and the activation.

        Args:
            intermediate_size: Width of the hidden (inner) layer.
            config: HuggingFace GPT-J config (``n_embd``,
                ``activation_function``).
            quant_config: Optional quantization config for both projections.
            prefix: Fully-qualified module name of this MLP; used to name
                ``fc_in``/``fc_out`` so name-based quantization rules can
                match them. Defaults to "" for backward compatibility.
        """
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc_in",
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc_out",
        )
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``fc_out(act(fc_in(hidden_states)))``."""
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):
    """One GPT-J decoder layer with parallel attention and MLP branches.

    A single LayerNorm (``ln_1``) feeds both the attention and the MLP, and
    their outputs are added to the un-normalised residual.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the layer norm, attention and MLP sub-layers.

        Args:
            config: HuggingFace GPT-J config.
            cache_config: KV-cache settings forwarded to the attention.
            quant_config: Optional quantization config for sub-layers.
            prefix: Fully-qualified module name of this block.
        """
        super().__init__()
        # HF convention: n_inner of None means 4x the hidden size.
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.mlp = GPTJMLP(inner_dim,
                           config,
                           quant_config,
                           prefix=f"{prefix}.mlp")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the block and return ``attn(x') + mlp(x') + x``.

        Here ``x'`` is ``ln_1(x)``.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # Parallel formulation: the MLP sees the same normalised input as
        # attention, not the attention output.
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


@support_torch_compile
class GPTJModel(nn.Module):
    """GPT-J backbone: token embedding, decoder blocks and final LayerNorm.

    Only layers ``[start_layer, end_layer)`` as returned by ``make_layers``
    are run on this rank. The first pipeline rank embeds tokens, and the
    last rank applies ``ln_f``. Other ranks exchange ``hidden_states`` via
    ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build embeddings, the local slice of blocks and the final norm.

        Args:
            vllm_config: Engine config; provides the HF config and the cache
                and quantization settings.
            prefix: Fully-qualified module name of this model.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_cfg = vllm_config.cache_config
        quant_cfg = vllm_config.quant_config

        self.config = hf_config
        self.embed_dim = hf_config.n_embd
        self.wte = VocabParallelEmbedding(hf_config.vocab_size, self.embed_dim)

        def build_block(prefix):
            # make_layers supplies the per-layer module name via ``prefix``.
            return GPTJBlock(hf_config, cache_cfg, quant_cfg, prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.n_layer,
            build_block,
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim,
                                 eps=hf_config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the backbone.

        Returns the ``ln_f`` output on the last pipeline rank; otherwise
        returns ``IntermediateTensors`` holding ``hidden_states`` for the
        next stage. ``kv_caches`` is indexed relative to ``start_layer``.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        for layer_idx in range(self.start_layer, self.end_layer):
            hidden_states = self.h[layer_idx](
                position_ids,
                hidden_states,
                kv_caches[layer_idx - self.start_layer],
                attn_metadata,
            )

        if pp_group.is_last_rank:
            return self.ln_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


class GPTJForCausalLM(nn.Module, SupportsPP):
    """GPT-J causal language model: backbone plus a biased LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            vllm_config: Engine config; provides the HF config and
                quantization settings.
            prefix: Fully-qualified module name of this model.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # This implementation requires an untied LM head: it has its own
        # bias, which compute_logits passes to the logits processor.
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        # Name the head so quantization configs that skip or select modules
        # by name (e.g. an "lm_head" ignore entry) can match it.
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings through the backbone."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; see ``GPTJModel.forward`` for return values."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits, including the bias."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HuggingFace GPT-J checkpoint tensors into this model.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are routed into
        the fused ``qkv_proj`` parameter with the matching shard id. KV-cache
        scales are resolved through the quantization config. Tensors with no
        matching parameter on this rank (extra GPTQ biases, layers owned by
        another pipeline stage) are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names of the parameters in ``self.named_parameters()`` that
            received a tensor.
        """
        # (fused param name, checkpoint shard name, shard id). GPT-J's MLP
        # uses unfused fc_in/fc_out, so only attention q/k/v are stacked.
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # These checkpoint entries (presumably HF attention-mask buffers)
            # have no counterpart among this model's parameters.
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Use the first element unless the scale is already 0-dim.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Resolve the target parameter name exactly once. Stop at the
            # first mapping that matches so the rewritten name is never fed
            # back into the mapping loop (note "v_proj" is a substring of
            # "qkv_proj").
            shard_id = None
            for param_name, weight_name, stacked_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_id
                    break
            else:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Parameter belongs to a layer owned by another pipeline stage.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused parameters always carry a shard-aware weight_loader.
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params