# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
21

22
from collections.abc import Iterable
23
from itertools import islice
24
25
26

import torch
from torch import nn
27
28
from transformers import GPTNeoXConfig

29
from vllm.attention import Attention
30
from vllm.compilation.decorators import support_torch_compile
31
from vllm.config import CacheConfig, VllmConfig
32
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
33
from vllm.model_executor.layers.activation import get_act_fn
34
35
36
37
38
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
from vllm.model_executor.layers.rotary_embedding import get_rope
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
43
44
45
    ParallelLMHead,
    VocabParallelEmbedding,
)
46
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
47
from vllm.sequence import IntermediateTensors
48

49
from .interfaces import SupportsPP
50
51
52
53
54
55
56
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
57

58
59

class GPTNeoXAttention(nn.Module):
    """Self-attention sub-block of a GPT-NeoX layer.

    Owns the fused QKV projection, the rotary embedding, the attention
    operator and the output projection (``dense``).
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the attention sub-modules.

        Args:
            config: Model config. Reads ``num_attention_heads``,
                ``hidden_size`` and ``rotary_pct``; ``attention_bias``,
                ``rope_theta`` and ``max_position_embeddings`` are optional
                and fall back to ``True``, ``10000`` and ``8192``.
            cache_config: Forwarded unchanged to ``Attention``.
            quant_config: Forwarded to both linear layers and ``Attention``.
            prefix: Dotted name of this module inside the model. It is used
                to derive the names handed to every sub-layer.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        self.bias = getattr(config, "attention_bias", True)

        # Heads are split evenly across tensor-parallel ranks.
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tensor_model_parallel_world_size == 0, (
            "num_attention_heads must be divisible by the tensor-parallel "
            "world size"
        )
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size

        # The linear layers receive their full dotted name, like ``attn``
        # below, so that name-based quantization lookups can resolve them.
        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        scaling = self.head_size**-0.5

        # Only a ``rotary_pct`` fraction of each head is rotated.
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0, "rotary dimension must be even"
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_size,
            scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            position_ids: Positions passed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.

        Returns:
            The output of the ``dense`` projection applied to the attention
            result.
        """
        qkv, _ = self.query_key_value(hidden_states)
        # The fused projection is laid out as three equal chunks: Q, K, V.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):
    """Feed-forward sub-block: two parallel linears around an activation."""

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: QuantizationConfig | None = None,
    ):
        """Create the two projections and the activation.

        Args:
            config: Model config; reads ``hidden_size``,
                ``intermediate_size`` and ``hidden_act``.
            quant_config: Forwarded to both linear layers.
        """
        super().__init__()
        hidden_dim = config.hidden_size
        inner_dim = config.intermediate_size

        # hidden_size -> intermediate_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_dim, inner_dim, quant_config=quant_config
        )
        # intermediate_size -> hidden_size
        self.dense_4h_to_h = RowParallelLinear(
            inner_dim, hidden_dim, quant_config=quant_config
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        """Apply ``dense_h_to_4h``, the activation, then ``dense_4h_to_h``."""
        # Each linear returns a pair; only the first element is used.
        expanded, _ = self.dense_h_to_4h(hidden_states)
        activated = self.act(expanded)
        projected, _ = self.dense_4h_to_h(activated)
        return projected


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX layer: attention and MLP, each behind a LayerNorm.

    ``config.use_parallel_residual`` selects between the parallel and the
    sequential residual formulation (see ``forward``).
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the two LayerNorms, the attention block and the MLP.

        Args:
            config: Model config; reads ``use_parallel_residual``,
                ``hidden_size`` and ``layer_norm_eps`` here and is passed on
                to the sub-blocks.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention block and the MLP.
            prefix: Dotted name of this layer; the attention block is named
                ``f"{prefix}.attention"``.
        """
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual

        norm_width = config.hidden_size
        norm_eps = config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(norm_width, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(norm_width, eps=norm_eps)

        self.attention = GPTNeoXAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.attention",
        )
        self.mlp = GPTNeoXMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the layer to ``hidden_states`` and return the new states."""
        residual = hidden_states
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=self.input_layernorm(residual),
        )

        if not self.use_parallel_residual:
            # Sequential form:
            #   x = x + attn(ln1(x))
            #   x = x + mlp(ln2(x))
            after_attn = attn_output + residual
            mlp_output = self.mlp(self.post_attention_layernorm(after_attn))
            return mlp_output + after_attn

        # Parallel form:
        #   x = x + attn(ln1(x)) + mlp(ln2(x))
        mlp_output = self.mlp(self.post_attention_layernorm(residual))
        return mlp_output + attn_output + residual


198
@support_torch_compile
class GPTNeoXModel(nn.Module):
    """GPT-NeoX backbone: token embedding, layer stack and final LayerNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone from ``vllm_config``.

        Args:
            vllm_config: Supplies the HF model config, the cache config and
                the quantization config.
            prefix: Dotted name of this module; layers are created under
                ``f"{prefix}.layers"``.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = hf_config

        self.embed_in = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        def _build_layer(prefix: str) -> GPTNeoXLayer:
            # Factory handed to make_layers; ``prefix`` is the layer's name.
            return GPTNeoXLayer(
                hf_config, cache_config, quant_config, prefix=prefix
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            _build_layer,
            prefix=f"{prefix}.layers",
        )
        self.final_layer_norm = nn.LayerNorm(
            hf_config.hidden_size, eps=hf_config.layer_norm_eps
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], hf_config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up ``input_ids`` in the input embedding table."""
        return self.embed_in(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the layers held by this pipeline-parallel rank.

        On the first rank the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``; on later ranks it is
        ``intermediate_tensors["hidden_states"]``. The last rank returns the
        final-LayerNorm output; every other rank returns an
        ``IntermediateTensors`` carrying ``hidden_states``.
        """
        if not get_pp_group().is_first_rank:
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(position_ids, hidden_states)

        if get_pp_group().is_last_rank:
            return self.final_layer_norm(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The names of the parameters that were loaded.
        """
        # Checkpoint entries that have no matching parameter here. The two
        # ``*_cached`` rotary tensors may appear in checkpoints of models
        # trained using OpenRLHF.
        skipped_markers = (
            "attention.bias",
            "attention.masked_bias",
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if any(marker in name for marker in skipped_markers):
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the required shape is
                # (3 * num_heads * head_size), so the weight is regrouped.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    original_shape = loaded_weight.shape
                    per_head_shape = (
                        original_shape[:output_dim]
                        + (num_heads, 3, -1)
                        + original_shape[output_dim + 1 :]
                    )
                    # Split out (num_heads, 3), swap them, then flatten back.
                    loaded_weight = (
                        loaded_weight.view(per_head_shape)
                        .transpose(output_dim, output_dim + 1)
                        .reshape(original_shape)
                    )

            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

291

292
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
    """GPT-NeoX backbone plus the ``embed_out`` LM head and logits processor."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone and the output head.

        Args:
            vllm_config: Supplies the HF model config and quantization config
                and is passed on to ``GPTNeoXModel``.
            prefix: Dotted name of this module, combined with the sub-module
                names through ``maybe_prefix``.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config

        backbone_prefix = maybe_prefix(prefix, "gpt_neox")
        self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config, prefix=backbone_prefix)

        self.embed_out = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "embed_out"),
        )
        if self.config.tie_word_embeddings:
            # Tied embeddings: the head reuses the input embedding weight.
            self.embed_out.weight = self.gpt_neox.embed_in.weight

        self.logits_processor = LogitsProcessor(self.config.vocab_size)
        # Expose the backbone's factory under the same attribute name.
        self.make_empty_intermediate_tensors = (
            self.gpt_neox.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone."""
        return self.gpt_neox.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return whatever it returns."""
        return self.gpt_neox(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Turn ``hidden_states`` into logits using the ``embed_out`` head."""
        return self.logits_processor(self.embed_out, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``weights`` via ``AutoWeightsLoader`` and return its result."""
        return AutoWeightsLoader(self).load_weights(weights)