gpt_neox.py 13.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
Woosuk Kwon's avatar
Woosuk Kwon committed
6
# Copyright 2023 The vLLM team.
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
21
from collections.abc import Iterable
22
from itertools import islice
23
from typing import Optional, Union
24

gaoqiong's avatar
gaoqiong committed
25
26
import os
import re
27
28
import torch
from torch import nn
29
30
from transformers import GPTNeoXConfig

31
from vllm.attention import Attention
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
35
from vllm.model_executor.layers.activation import get_act_fn
36
37
38
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
from vllm.model_executor.layers.rotary_embedding import get_rope
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
43
    ParallelLMHead, VocabParallelEmbedding)
44
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
45
from vllm.model_executor.sampling_metadata import SamplingMetadata
46
from vllm.sequence import IntermediateTensors
47

48
from .interfaces import SupportsPP
49
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
50
51
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
gaoqiong's avatar
gaoqiong committed
52
from vllm import _custom_ops as ops
53
54
55

class GPTNeoXAttention(nn.Module):
    """GPT-NeoX self-attention with rotary position embeddings.

    Attention heads are divided evenly across tensor-parallel ranks; this
    rank holds ``num_attention_heads // tp_world_size`` of them. The rotary
    embedding is built with ``rotary_dim = int(head_size * rotary_pct)``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        # Projections are biased unless the config sets `attention_bias`.
        self.bias = getattr(config, "attention_bias", True)

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # The layer prefix is forwarded so that quantization configs which
        # choose a quant method per layer name can match these projections
        # (previously they were created with an empty name).
        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        scaling = self.head_size**-0.5

        # Only a fraction (`rotary_pct`) of each head uses rotary embedding;
        # the rotary dimension must be even.
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states`` at ``position_ids``.

        Returns the output of the ``dense`` projection (its bias return
        value is discarded).
        """
        qkv, _ = self.query_key_value(hidden_states)
        # GPTNeoXModel.load_weights reorders the fused QKV weight so the
        # last dimension is laid out as [q | k | v], allowing a plain chunk.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):
    """Feed-forward sub-block of a GPT-NeoX layer.

    Up-projects ``hidden_size -> intermediate_size`` with a column-parallel
    linear layer, applies the activation named by ``config.hidden_act``,
    then down-projects back to ``hidden_size`` with a row-parallel layer.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.hidden_size
        inner = config.intermediate_size
        self.dense_h_to_4h = ColumnParallelLinear(
            width,
            inner,
            quant_config=quant_config,
        )
        self.dense_4h_to_h = RowParallelLinear(
            inner,
            width,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        # Each parallel linear returns a pair; only the first element
        # (the projected tensor) is used.
        up, _ = self.dense_h_to_4h(hidden_states)
        down, _ = self.dense_4h_to_h(self.act(up))
        return down


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX decoder layer: pre-norm attention plus pre-norm MLP.

    When ``config.use_parallel_residual`` is true the two branches both read
    the layer input (``x + attn(ln1(x)) + mlp(ln2(x))``); otherwise they are
    applied one after the other, each with its own residual connection.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        width = config.hidden_size
        eps = config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)
        self.attention = GPTNeoXAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.attention",
        )
        self.mlp = GPTNeoXMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        attn_out = self.attention(
            position_ids=position_ids,
            hidden_states=self.input_layernorm(residual),
        )

        if not self.use_parallel_residual:
            # Sequential form:
            #   x = x + attn(ln1(x))
            #   x = x + mlp(ln2(x))
            residual = attn_out + residual
            mlp_out = self.mlp(self.post_attention_layernorm(residual))
            return mlp_out + residual

        # Parallel form: x = x + attn(ln1(x)) + mlp(ln2(x))
        mlp_out = self.mlp(self.post_attention_layernorm(residual))
        return mlp_out + attn_out + residual


196
@support_torch_compile
class GPTNeoXModel(nn.Module):
    """GPT-NeoX decoder stack: token embedding, layers, final LayerNorm.

    Under pipeline parallelism this rank runs only layers
    ``[start_layer, end_layer)``; ranks other than the last return
    ``IntermediateTensors`` instead of normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTNeoXLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

        # Both attributes are always defined. Previously `quant_config` was
        # only assigned when quantization was enabled, so reading it on an
        # unquantized model raised AttributeError.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_in(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        The first pipeline rank starts from ``inputs_embeds`` (if given) or
        from the embeddings of ``input_ids``. Later ranks start from
        ``intermediate_tensors["hidden_states"]``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return self.final_layer_norm(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs whose names are looked up
                directly in ``self.named_parameters()``.

        Returns:
            The set of parameter names that were loaded.
        """
        # Checkpoint entries skipped rather than loaded. The rotary
        # cos/sin caches appear in models trained with OpenRLHF.
        skipped_substrings = (
            "attention.bias",
            "attention.masked_bias",
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        num_heads = self.config.num_attention_heads
        for name, loaded_weight in weights:
            if any(s in name for s in skipped_substrings):
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
292

293

294
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
    """GPT-NeoX causal language model: decoder stack plus output head.

    The output head (``embed_out``) shares its weight with the input
    embedding when ``config.tie_word_embeddings`` is set.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "gpt_neox"))
        # The head's checkpoint name is forwarded so that quantization
        # configs which choose a quant method per layer name can match it
        # (previously it was created with an empty name).
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "embed_out"),
        )
        if self.config.tie_word_embeddings:
            self.embed_out.weight = self.gpt_neox.embed_in.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.gpt_neox.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the decoder stack."""
        return self.gpt_neox.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack; logits are produced by `compute_logits`."""
        hidden_states = self.gpt_neox(input_ids, positions,
                                      intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through `embed_out`."""
        logits = self.logits_processor(self.embed_out, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via `AutoWeightsLoader`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)