gpt_j.py 12.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
17
"""Inference-only GPT-J model compatible with HuggingFace weights."""
18
from typing import Iterable, List, Optional, Set, Tuple, Union
19
20
21
22
23

import torch
from torch import nn
from transformers import GPTJConfig

24
from vllm.attention import Attention, AttentionMetadata
25
from vllm.compilation.decorators import support_torch_compile
26
from vllm.config import CacheConfig, VllmConfig
27
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
28
from vllm.model_executor.layers.activation import get_act_fn
29
30
31
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
32
from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
from vllm.model_executor.layers.quantization import QuantizationConfig
34
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
35
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
36
from vllm.model_executor.layers.vocab_parallel_embedding import (
37
    ParallelLMHead, VocabParallelEmbedding)
38
39
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import IntermediateTensors
42

43
44
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
45
46
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
47

48
49
50

class GPTJAttention(nn.Module):
    """Self-attention sub-layer of a GPT-J block.

    Applies a fused QKV projection, rotary position embedding on the
    query/key tensors, the attention op, and a final output projection.

    Args:
        config: GPT-J configuration; reads ``num_attention_heads``,
            ``hidden_size``, ``rotary_dim`` and, when present, ``rotary``,
            ``rope_theta`` and ``max_position_embeddings``.
        cache_config: Forwarded to ``Attention``.
        quant_config: Forwarded to the linear layers and to ``Attention``.
        prefix: Name prefix of this module; ``Attention`` is created with
            ``f"{prefix}.attn"``.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        # Both projections are created without bias terms.
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        # The heads must divide evenly across tensor-parallel ranks;
        # ``num_heads`` is the per-rank head count.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        # 1 / sqrt(head_size)
        scaling = self.head_size**-0.5
        # Only rotary position embeddings with an even rotary dimension
        # are handled by this implementation.
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        # Fall back to defaults when the config does not define these.
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            position_ids: Positions passed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Cache tensor handed to the attention op.
            attn_metadata: Metadata handed to the attention op.

        Returns:
            The output of ``out_proj`` applied to the attention result.
        """
        # The projection layers return a pair; only the first item is used.
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into three equal parts on the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):
    """Feed-forward sub-layer of a GPT-J block.

    Computes ``fc_out(act(fc_in(x)))``.

    Args:
        intermediate_size: Output width of ``fc_in`` / input width of
            ``fc_out``.
        config: GPT-J configuration; reads ``n_embd`` and
            ``activation_function``.
        quant_config: Forwarded to both linear layers.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=quant_config,
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=quant_config,
        )
        # Activation is looked up by the name stored in the config.
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the up-projection, activation, and down-projection."""
        # The linear layers return a pair; only the first item is used.
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):
    """One GPT-J transformer layer.

    A single LayerNorm feeds both the attention and the MLP sub-layers;
    their outputs are summed together with the residual input.

    Args:
        config: GPT-J configuration; reads ``n_embd``, ``n_inner`` and
            ``layer_norm_epsilon`` (and is passed on to the sub-layers).
        cache_config: Forwarded to ``GPTJAttention``.
        quant_config: Forwarded to ``GPTJAttention`` and ``GPTJMLP``.
        prefix: Name prefix of this layer; the attention sub-layer is
            created with ``f"{prefix}.attn"``.
    """

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # MLP width defaults to 4 * n_embd when the config leaves it unset.
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.mlp = GPTJMLP(inner_dim, config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the layer and return the updated hidden states.

        Args:
            position_ids: Positions forwarded to the attention sub-layer.
            hidden_states: Layer input; also used as the residual.
            kv_cache: Cache tensor forwarded to the attention sub-layer.
            attn_metadata: Metadata forwarded to the attention sub-layer.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # The MLP consumes the same normalized input as attention, not the
        # attention output.
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


184
@support_torch_compile
class GPTJModel(nn.Module):
    """GPT-J backbone: token embedding, layer stack, and final LayerNorm.

    The layer stack is built through ``make_layers``, which also returns
    the ``[start_layer, end_layer)`` range iterated by ``forward``.

    Args:
        vllm_config: Supplies the HF model config, cache config and
            quantization config.
        prefix: Name prefix of this module; layers are created under
            ``f"{prefix}.h"``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.start_layer, self.end_layer, self.h = make_layers(
            config.n_layer,
            lambda prefix: GPTJBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h",
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the layers owned by this pipeline-parallel rank.

        Args:
            input_ids: Token ids; embedded on the first rank unless
                ``inputs_embeds`` is given.
            position_ids: Positions forwarded to every layer.
            kv_caches: Per-layer cache tensors, indexed relative to
                ``start_layer``.
            attn_metadata: Metadata forwarded to every layer.
            intermediate_tensors: Source of ``"hidden_states"`` on ranks
                other than the first.
            inputs_embeds: Optional precomputed embeddings that replace
                the embedding lookup on the first rank.

        Returns:
            On the last rank, the hidden states after the final LayerNorm;
            otherwise an ``IntermediateTensors`` holding ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Fail clearly instead of with a TypeError on subscripting None.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


244
class GPTJForCausalLM(nn.Module, SupportsPP):
    """GPT-J with a language-modeling head, logits processor, and sampler.

    Args:
        vllm_config: Supplies the HF model config and quantization config;
            also passed through to ``GPTJModel``.
        prefix: Name prefix of this module; the backbone is created under
            ``maybe_prefix(prefix, "transformer")``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # A separate lm_head is always created below, so tied word
        # embeddings are rejected up front.
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's token embeddings for ``input_ids``."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its result unchanged.

        All arguments are forwarded positionally to ``GPTJModel``.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` via the logits processor.

        The LM head and its bias are both handed to the processor.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` using the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of names added to ``loaded_params``: every pair that
            reaches the end of the loop body, i.e. is not skipped by one
            of the ``continue`` statements in the outer loop or the
            ``else`` branch.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # These checkpoint entries have no matching parameter here.
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Redirect the per-shard name to the fused parameter name.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached when no stacked mapping loaded the weight (the
                # inner loop finished without ``break``).
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params