gpt_neox.py 13.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
Woosuk Kwon's avatar
Woosuk Kwon committed
5
# Copyright 2023 The vLLM team.
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
19
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
20
from typing import Iterable, Optional, Set, Tuple, Union
21

gaoqiong's avatar
gaoqiong committed
22
23
import os
import re
24
25
import torch
from torch import nn
26
27
from transformers import GPTNeoXConfig

28
from vllm.attention import Attention
29
from vllm.compilation.decorators import support_torch_compile
30
from vllm.config import CacheConfig, VllmConfig
31
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
32
from vllm.model_executor.layers.activation import get_act_fn
33
34
35
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
36
from vllm.model_executor.layers.logits_processor import LogitsProcessor
37
from vllm.model_executor.layers.quantization import QuantizationConfig
38
from vllm.model_executor.layers.rotary_embedding import get_rope
39
from vllm.model_executor.layers.vocab_parallel_embedding import (
40
    ParallelLMHead, VocabParallelEmbedding)
41
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
42
from vllm.model_executor.sampling_metadata import SamplingMetadata
43
from vllm.sequence import IntermediateTensors
44

45
from .interfaces import SupportsPP
46
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
47
48
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
gaoqiong's avatar
gaoqiong committed
49
from vllm import _custom_ops as ops
50
51
52

class GPTNeoXAttention(nn.Module):
    """Self-attention sub-block of a GPT-NeoX decoder layer.

    Consists of a fused QKV projection (``query_key_value``), a rotary
    position embedding (``rotary_emb``) applied to the first
    ``rotary_pct`` fraction of each head, the ``Attention`` op (``attn``)
    and an output projection (``dense``). The attention heads are divided
    evenly across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = hidden
        self.head_size = hidden // self.total_num_heads
        # Configs without an ``attention_bias`` field default to biased
        # projections.
        self.bias = getattr(config, "attention_bias", True)

        tp_size = get_tensor_model_parallel_world_size()
        # Each tensor-parallel rank must own the same whole number of heads.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.query_key_value = QKVParallelLinear(
            hidden,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            hidden,
            hidden,
            bias=self.bias,
            quant_config=quant_config,
        )

        # Only the leading ``rotary_pct`` share of each head gets rotary
        # embedding. The width must be even (presumably because RoPE rotates
        # channel pairs).
        partial_dim = int(self.head_size * config.rotary_pct)
        assert partial_dim % 2 == 0
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=partial_dim,
            max_position=getattr(config, "max_position_embeddings", 8192),
            base=getattr(config, "rope_theta", 10000),
        )

        # Standard 1/sqrt(head_size) softmax scaling.
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              self.head_size**-0.5,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            position_ids: Token positions consumed by the rotary embedding.
            hidden_states: Input activations for this layer.

        Returns:
            The output of the ``dense`` projection.
        """
        fused, _ = self.query_key_value(hidden_states)
        # The fused projection is split into three equal parts on the last
        # dimension: query, key, value.
        query, key, value = fused.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.dense(context)
        return projected


class GPTNeoXMLP(nn.Module):
    """Feed-forward sub-block: up-projection, activation, down-projection.

    ``dense_h_to_4h`` maps ``hidden_size -> intermediate_size`` and
    ``dense_4h_to_h`` maps it back. The activation is ``config.hidden_act``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.hidden_size
        inner = config.intermediate_size
        self.dense_h_to_4h = ColumnParallelLinear(width,
                                                  inner,
                                                  quant_config=quant_config)
        self.dense_4h_to_h = RowParallelLinear(inner,
                                               width,
                                               quant_config=quant_config)
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        """Return ``dense_4h_to_h(act(dense_h_to_4h(hidden_states)))``."""
        expanded, _ = self.dense_h_to_4h(hidden_states)
        reduced, _ = self.dense_4h_to_h(self.act(expanded))
        return reduced


class GPTNeoXLayer(nn.Module):
    """A single GPT-NeoX decoder layer.

    Attention consumes ``input_layernorm(x)``. With
    ``config.use_parallel_residual`` the MLP consumes
    ``post_attention_layernorm(x)`` of the same input, and both branches are
    added to ``x``. Otherwise the attention residual is applied first and
    the MLP runs on the normalized result.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        width, eps = config.hidden_size, config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)
        self.attention = GPTNeoXAttention(config,
                                          cache_config,
                                          quant_config,
                                          prefix=f"{prefix}.attention")
        self.mlp = GPTNeoXMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and MLP with residual connections.

        Args:
            position_ids: Token positions forwarded to the attention block.
            hidden_states: Layer input activations.

        Returns:
            The layer output, computed from the same input via
            either the parallel or the sequential residual scheme.
        """
        residual = hidden_states
        attn_out = self.attention(
            position_ids=position_ids,
            hidden_states=self.input_layernorm(residual),
        )

        if not self.use_parallel_residual:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            residual = attn_out + residual
            mlp_out = self.mlp(self.post_attention_layernorm(residual))
            return mlp_out + residual

        # x = x + attn(ln1(x)) + mlp(ln2(x))
        mlp_out = self.mlp(self.post_attention_layernorm(residual))
        return mlp_out + attn_out + residual


193
@support_torch_compile
class GPTNeoXModel(nn.Module):
    """GPT-NeoX decoder stack: token embedding, layers, final LayerNorm.

    Under pipeline parallelism, only the first PP rank embeds tokens (or
    uses ``inputs_embeds``). Other ranks read ``hidden_states`` from
    ``intermediate_tensors``. Every rank except the last returns
    ``IntermediateTensors`` instead of normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # forward() only runs layers[start_layer:end_layer]. Presumably that
        # is this pipeline stage's share of the stack.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTNeoXLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

        # Quantization bookkeeping. Both attributes are always defined, even
        # for unquantized models (where they are None), so reading them never
        # raises AttributeError.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_in(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids. Only embedded on the first PP rank when
                ``inputs_embeds`` is not given.
            position_ids: Token positions passed to every layer.
            intermediate_tensors: Activations from the previous PP stage.
                Must be provided on every rank except the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                ``embed_in`` lookup on the first rank.

        Returns:
            ``final_layer_norm`` output on the last PP rank. On other ranks,
            an ``IntermediateTensors`` holding ``hidden_states``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Later stages cannot proceed without the previous stage's
            # activations; fail with a clear message instead of a
            # NoneType subscript error.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = layer(position_ids, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return self.final_layer_norm(hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Entries whose names contain one of the substrings below are skipped,
        as are parameters that belong to another pipeline stage. Fused QKV
        tensors are permuted from the checkpoint layout
        (num_heads * 3 * head_size) to the (3 * num_heads * head_size)
        layout this model requires.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            The set of parameter names that were loaded.
        """
        # Checkpoint entries that are not parameters of this model.
        # The rotary cos/sin caches may appear in models trained with
        # OpenRLHF.
        skipped_substrings = (
            "attention.bias",
            "attention.masked_bias",
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if any(part in name for part in skipped_substrings):
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion. The checkpoint holds the
                # full (unsharded) tensor, hence the total head count.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    # Expose the (num_heads, 3) axes, swap them, then flatten
                    # back to the original shape.
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params
289

290

291
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
    """GPT-NeoX decoder (``gpt_neox``) plus the LM head (``embed_out``).

    When ``config.tie_word_embeddings`` is set, ``embed_out`` reuses the
    input embedding's weight Parameter.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config
        self.config = hf_config
        self.quant_config = quant_cfg

        self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "gpt_neox"))
        self.embed_out = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_cfg,
        )
        if hf_config.tie_word_embeddings:
            # Share a single Parameter between input embedding and LM head.
            self.embed_out.weight = self.gpt_neox.embed_in.weight
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Delegate PP placeholder creation to the inner decoder.
        self.make_empty_intermediate_tensors = (
            self.gpt_neox.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the inner decoder."""
        return self.gpt_neox.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward all inputs unchanged to ``gpt_neox`` and return its result."""
        return self.gpt_neox(input_ids, positions, intermediate_tensors,
                             inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` to vocabulary logits with ``embed_out``."""
        return self.logits_processor(self.embed_out, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``; return loaded names."""
        return AutoWeightsLoader(self).load_weights(weights)