gpt_neox.py 12.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
Woosuk Kwon's avatar
Woosuk Kwon committed
6
# Copyright 2023 The vLLM team.
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
21

22
from collections.abc import Iterable
23
from itertools import islice
24

gaoqiong's avatar
gaoqiong committed
25
26
import os
import re
27
28
import torch
from torch import nn
29
30
from transformers import GPTNeoXConfig

31
from vllm.attention.layer import Attention
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
35
from vllm.model_executor.layers.activation import get_act_fn
36
37
38
39
40
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.quantization import QuantizationConfig
43
from vllm.model_executor.layers.rotary_embedding import get_rope
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
45
46
47
    ParallelLMHead,
    VocabParallelEmbedding,
)
48
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
49
from vllm.sequence import IntermediateTensors
50

51
from .interfaces import SupportsPP
52
53
54
55
56
57
58
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
gaoqiong's avatar
gaoqiong committed
59
from vllm import _custom_ops as ops
60
61
62


class GPTNeoXAttention(nn.Module):
    """Self-attention sub-layer of a GPT-NeoX block.

    Projects hidden states to a fused Q/K/V tensor, applies rotary position
    embeddings to Q and K, runs attention over the heads owned by this
    tensor-parallel rank, and projects the result back with ``dense``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        # Projections are biased unless the config provides ``attention_bias``.
        self.bias = getattr(config, "attention_bias", True)

        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        # Explicit check instead of ``assert`` so it survives ``python -O``;
        # without it the floor division below would silently drop heads.
        if self.total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_attention_heads ({self.total_num_heads}) must be "
                "divisible by the tensor parallel world size "
                f"({tensor_model_parallel_world_size})."
            )
        # Number of attention heads handled by this tensor-parallel rank.
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size

        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.rotary_emb = get_rope(
            self.head_size,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )
        # 1/sqrt(head_size) attention scaling.
        scaling = self.head_size**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_size,
            scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute self-attention for the given tokens.

        Args:
            position_ids: Token positions fed to the rotary embedding.
            hidden_states: Input activations; presumably
                ``(num_tokens, hidden_size)`` — TODO confirm.

        Returns:
            The output of the ``dense`` projection.
        """
        qkv, _ = self.query_key_value(hidden_states)
        # Split into three equal parts along the last dim. This relies on a
        # [Q | K | V] layout; GPTNeoXModel.load_weights reorders the
        # checkpoint's per-head interleaved QKV weights into that layout.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):
    """Feed-forward sub-layer: up-projection, activation, down-projection."""

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden = config.hidden_size
        intermediate = config.intermediate_size

        # hidden_size -> intermediate_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden,
            intermediate,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_h_to_4h",
        )
        # intermediate_size -> hidden_size
        self.dense_4h_to_h = RowParallelLinear(
            intermediate,
            hidden,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        """Apply up-projection, activation and down-projection in sequence.

        The second element returned by each linear layer is discarded.
        """
        up, _ = self.dense_h_to_4h(hidden_states)
        down, _ = self.dense_4h_to_h(self.act(up))
        return down


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX decoder block: LayerNorm'd attention and LayerNorm'd MLP.

    ``config.use_parallel_residual`` selects whether the MLP reads the block
    input (parallel) or the post-attention residual stream (sequential).
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        width = config.hidden_size
        eps = config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)
        self.attention = GPTNeoXAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.attention"
        )
        self.mlp = GPTNeoXMLP(config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the block on ``hidden_states`` and return the updated stream."""
        residual = hidden_states
        attn_out = self.attention(
            position_ids=position_ids,
            hidden_states=self.input_layernorm(residual),
        )

        if not self.use_parallel_residual:
            # Sequential residual:
            #   h   = x + attn(ln1(x))
            #   out = h + mlp(ln2(h))
            h = attn_out + residual
            return self.mlp(self.post_attention_layernorm(h)) + h

        # Parallel residual: both branches read the same block input x.
        #   out = x + attn(ln1(x)) + mlp(ln2(x))
        mlp_out = self.mlp(self.post_attention_layernorm(residual))
        return mlp_out + attn_out + residual


202
@support_torch_compile
class GPTNeoXModel(nn.Module):
    """GPT-NeoX decoder stack: token embedding, GPTNeoXLayer stack, final LayerNorm.

    With pipeline parallelism this rank only runs layers in
    ``[start_layer, end_layer)``. Non-first ranks take ``hidden_states`` from
    ``intermediate_tensors``. Non-last ranks return them wrapped in
    ``IntermediateTensors`` instead of applying the final LayerNorm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTNeoXLayer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self.final_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

        # Always define both attributes (None when unquantized) so callers can
        # read them without hasattr checks, matching GPTNeoXForCausalLM,
        # which sets ``quant_config`` unconditionally.
        self.quant_config = quant_config
        self.quant_method = (
            quant_config.get_name() if quant_config is not None else None
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids`` via ``embed_in``."""
        return self.embed_in(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder.

        Args:
            input_ids: Token ids; only used on the first PP rank when
                ``inputs_embeds`` is None.
            position_ids: Token positions passed to every layer.
            intermediate_tensors: Activations from the previous PP rank; read
                only on non-first ranks.
            inputs_embeds: Optional precomputed embeddings that replace the
                ``input_ids`` lookup on the first PP rank.

        Returns:
            ``IntermediateTensors`` on non-last PP ranks, otherwise the
            hidden states after ``final_layer_norm``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
        else:
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs whose names match keys of
                ``self.named_parameters()``.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # These checkpoint entries have no matching parameter and are
            # skipped (presumably HF-side mask / rotary buffers — confirm).
            if (
                "attention.bias" in name
                or "attention.masked_bias" in name
                or "rotary_emb.inv_freq" in name
            ):
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using OpenRLHF may include
                # these tensors in the checkpoint. Skip them.
                continue
            # Parameters owned by another pipeline stage are not present here.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    # Split output_dim into (num_heads, 3, head_size) ...
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim]
                        + (num_heads, 3, -1)
                        + loaded_weight_shape[output_dim + 1 :]
                    )
                    # ... swap to (3, num_heads, head_size) ...
                    loaded_weight = loaded_weight.transpose(output_dim, output_dim + 1)
                    # ... and flatten back to the original shape.
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
299

300

301
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
    """GPT-NeoX causal LM: the ``gpt_neox`` decoder plus the ``embed_out`` head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config
        self.config = hf_config
        self.quant_config = quant_cfg

        self.gpt_neox = GPTNeoXModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "gpt_neox")
        )
        self.embed_out = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_cfg,
            prefix=maybe_prefix(prefix, "embed_out"),
        )
        if hf_config.tie_word_embeddings:
            # Reuse the input embedding matrix as the output projection.
            self.embed_out.weight = self.gpt_neox.embed_in.weight
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Delegate PP placeholder creation to the inner decoder.
        self.make_empty_intermediate_tensors = (
            self.gpt_neox.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings from the inner decoder."""
        return self.gpt_neox.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder; logits are produced separately by compute_logits."""
        return self.gpt_neox(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits through ``embed_out``."""
        return self.logits_processor(self.embed_out, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Dispatch checkpoint tensors to submodules via AutoWeightsLoader."""
        return AutoWeightsLoader(self).load_weights(weights)