# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union

import torch
from torch import nn
from transformers import GPTNeoXConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class GPTNeoXAttention(nn.Module):
    """Self-attention sub-layer of a GPT-NeoX block.

    The layer is made of a fused query/key/value projection
    (``query_key_value``), a rotary position embedding (``rotary_emb``),
    the attention operator (``attn``) and an output projection (``dense``).
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the attention sub-modules from ``config``.

        Args:
            config: Model configuration. Reads ``num_attention_heads``,
                ``hidden_size``, ``rotary_pct`` and, when present,
                ``attention_bias``, ``rope_theta`` and
                ``max_position_embeddings``.
            cache_config: Forwarded unchanged to ``Attention``.
            quant_config: Forwarded unchanged to the linear layers and to
                ``Attention``.
            prefix: Name prefix of this module; ``f"{prefix}.attn"`` is
                passed to ``Attention``.

        Raises:
            AssertionError: If the head count is not divisible by the
                tensor-parallel world size, or if the rotary dimension is
                odd.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        # Configs that do not define ``attention_bias`` get biased
        # projections.
        self.bias = getattr(config, "attention_bias", True)

        # Each tensor-parallel rank owns an equal share of the heads.
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0, (
            "num_attention_heads must be divisible by the tensor-parallel "
            "world size")
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            quant_config=quant_config,
        )

        scaling = self.head_size**-0.5
        # Only a ``rotary_pct`` fraction of every head is rotated.
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0, "rotary dimension must be even"
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            position_ids: Positions handed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.

        Returns:
            The output of the ``dense`` projection applied to the attention
            result.
        """
        qkv, _ = self.query_key_value(hidden_states)
        # The fused projection is split into three equal parts along the
        # last dimension: query, key, value.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):
    """Feed-forward sub-layer of a GPT-NeoX block.

    Applies ``dense_h_to_4h``, the activation named by
    ``config.hidden_act``, then ``dense_4h_to_h``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the two projections and the activation.

        Args:
            config: Model configuration. Reads ``hidden_size``,
                ``intermediate_size`` and ``hidden_act``.
            quant_config: Forwarded unchanged to both linear layers.
        """
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
        )
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``dense_4h_to_h(act(dense_h_to_4h(hidden_states)))``.

        Each linear layer returns a pair; only its first element is used.
        """
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX transformer block: attention plus MLP with residuals.

    ``config.use_parallel_residual`` selects between the parallel form
    (attention and MLP both read the block input) and the sequential form
    (the MLP reads the attention residual output).
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the layer norms, attention and MLP sub-modules.

        Args:
            config: Model configuration. Reads ``use_parallel_residual``,
                ``hidden_size`` and ``layer_norm_eps`` here; it is also
                passed on to the attention and MLP sub-modules.
            cache_config: Forwarded to ``GPTNeoXAttention``.
            quant_config: Forwarded to ``GPTNeoXAttention`` and
                ``GPTNeoXMLP``.
            prefix: Name prefix of this layer; the attention sub-module
                receives ``f"{prefix}.attention"``.
        """
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config,
                                          cache_config,
                                          quant_config,
                                          prefix=f"{prefix}.attention")
        self.mlp = GPTNeoXMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the block to ``hidden_states``.

        Args:
            position_ids: Positions forwarded to the attention sub-module.
            hidden_states: Block input.

        Returns:
            The block output, computed with either the parallel or the
            sequential residual formulation.
        """
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
        )

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


@support_torch_compile
class GPTNeoXModel(nn.Module):
    """GPT-NeoX backbone: input embedding, layer stack and final layer norm.

    The layer stack is created through ``make_layers``, which also yields
    the ``[start_layer, end_layer)`` range this instance iterates over in
    ``forward``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone from ``vllm_config``.

        Args:
            vllm_config: Source of the HuggingFace model config
                (``model_config.hf_config``), the cache config and the
                quantization config.
            prefix: Name prefix of this module; the layer stack is created
                under ``f"{prefix}.layers"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTNeoXLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the ``embed_in`` embeddings of ``input_ids``."""
        return self.embed_in(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the layer stack.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            position_ids: Positions forwarded to every layer.
            intermediate_tensors: On ranks other than the first, supplies
                the incoming ``"hidden_states"``.
            inputs_embeds: Optional precomputed embeddings; used instead of
                ``input_ids`` on the first pipeline rank.

        Returns:
            On the last pipeline rank, the hidden states after the final
            layer norm. On any other rank, an ``IntermediateTensors``
            holding ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The names of the parameters that were loaded.

        Raises:
            KeyError: If a name that is not skipped has no matching
                parameter on this module.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        # Loop-invariant: needed only for the fused-QKV conversion below.
        num_heads = self.config.num_attention_heads
        for name, loaded_weight in weights:
            # Checkpoint entries with no corresponding parameter here.
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using OpenRLHF may include
                # these tensors in the checkpoint. Skip them.
                continue
            # NOTE(review): presumably skips parameters held by another
            # pipeline rank -- confirm against is_pp_missing_parameter.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    # Split output_dim into (num_heads, 3, head_size), swap
                    # the first two of those axes, then flatten back to the
                    # original shape.
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            # Parameters may carry their own loader; otherwise use the
            # default one.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

class GPTNeoXForCausalLM(nn.Module, SupportsPP):
    """GPT-NeoX backbone with a language-model head (``embed_out``)."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the LM head and the logits processor.

        Args:
            vllm_config: Source of the HuggingFace model config and the
                quantization config; also passed whole to ``GPTNeoXModel``.
            prefix: Name prefix of this module; the backbone is created
                under ``maybe_prefix(prefix, "gpt_neox")``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "gpt_neox"))
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        # With tied embeddings the LM head shares the input embedding
        # weight.
        if self.config.tie_word_embeddings:
            self.embed_out.weight = self.gpt_neox.embed_in.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.gpt_neox.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's input embeddings of ``input_ids``."""
        return self.gpt_neox.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its result unchanged.

        Logits are not computed here; see ``compute_logits``.
        """
        hidden_states = self.gpt_neox(input_ids, positions,
                                      intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return logits from ``logits_processor`` using ``embed_out``."""
        logits = self.logits_processor(self.embed_out, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``weights`` via ``AutoWeightsLoader``; return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)