# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTNeoXConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput


class GPTNeoXAttention(nn.Module):
    """GPT-NeoX self-attention with rotary position embeddings.

    A single fused projection (``query_key_value``) produces Q, K and V.
    Rotary embeddings are applied to Q and K, attention is delegated to
    vLLM's ``Attention`` op, and ``dense`` projects the result back to the
    hidden size.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        total_heads = config.num_attention_heads
        self.total_num_heads = total_heads
        self.hidden_size = hidden_size
        self.head_size = hidden_size // total_heads
        # Projection biases are enabled unless the config turns them off.
        self.bias = getattr(config, "attention_bias", True)

        # Heads are divided evenly among tensor-parallel ranks; each rank
        # keeps ``num_heads`` of them.
        tp_size = get_tensor_model_parallel_world_size()
        assert total_heads % tp_size == 0
        self.num_heads = total_heads // tp_size

        self.query_key_value = QKVParallelLinear(
            hidden_size,
            self.head_size,
            total_heads,
            bias=self.bias,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            hidden_size,
            hidden_size,
            bias=self.bias,
            linear_method=linear_method,
        )

        # Only the leading ``rotary_pct`` fraction of each head's channels
        # is rotated; that count must be even.
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=getattr(config, "max_position_embeddings", 8192),
            base=getattr(config, "rope_theta", 10000),
        )
        # Standard 1/sqrt(head_size) softmax scaling.
        self.attn = Attention(self.num_heads, self.head_size,
                              self.head_size**-0.5)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            position_ids: Positions fed to the rotary embedding.
            hidden_states: Layer input activations.
            kv_cache: KV cache for this layer, passed to ``self.attn``.
            attn_metadata: Attention metadata, passed to ``self.attn``.

        Returns:
            The output of the ``dense`` projection.
        """
        fused_qkv, _ = self.query_key_value(hidden_states)
        # The fused output is split into three equal chunks (Q, K, V) along
        # the last dimension.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.dense(context)
        return projected


class GPTNeoXMLP(nn.Module):
    """GPT-NeoX feed-forward block.

    Computes ``dense_4h_to_h(act(dense_h_to_4h(x)))``: an up-projection from
    ``hidden_size`` to ``intermediate_size``, the configured activation, and a
    down-projection back to ``hidden_size``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden,
            inner,
            linear_method=linear_method,
        )
        self.dense_4h_to_h = RowParallelLinear(
            inner,
            hidden,
            linear_method=linear_method,
        )
        # ``get_act_fn`` receives the linear method's ``quant_config`` when it
        # has one (None otherwise, including when linear_method is None).
        self.act = get_act_fn(config.hidden_act,
                              getattr(linear_method, "quant_config", None),
                              inner)

    def forward(self, hidden_states):
        """Apply up-projection, activation and down-projection."""
        expanded, _ = self.dense_h_to_4h(hidden_states)
        contracted, _ = self.dense_4h_to_h(self.act(expanded))
        return contracted


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX transformer layer: attention and MLP with residuals.

    With ``config.use_parallel_residual`` the attention and MLP branches both
    read the layer input. Otherwise they run sequentially, and the MLP sees
    the post-attention residual stream.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        width, eps = config.hidden_size, config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)
        self.attention = GPTNeoXAttention(config, linear_method)
        self.mlp = GPTNeoXMLP(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply attention and MLP to ``hidden_states`` with residual adds."""
        residual = hidden_states
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=self.input_layernorm(residual),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if not self.use_parallel_residual:
            # Sequential:
            #   x = x + attn(ln1(x))
            #   x = x + mlp(ln2(x))
            residual = attn_output + residual
            return self.mlp(self.post_attention_layernorm(residual)) + residual

        # Parallel: x = x + attn(ln1(x)) + mlp(ln2(x))
        mlp_output = self.mlp(self.post_attention_layernorm(residual))
        return mlp_output + attn_output + residual


class GPTNeoXModel(nn.Module):
    """GPT-NeoX decoder stack: token embedding, N layers, final LayerNorm."""

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            GPTNeoXLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every layer.

        Args:
            input_ids: Token ids to embed.
            position_ids: Positions forwarded to every layer.
            kv_caches: One KV cache per layer, indexed by layer position.
                An IndexError is raised if it has fewer entries than there
                are layers.
            attn_metadata: Attention metadata forwarded to every layer.

        Returns:
            Hidden states after the final LayerNorm.
        """
        hidden_states = self.embed_in(input_ids)
        # enumerate (rather than zip) keeps the IndexError on a short
        # kv_caches list instead of silently skipping layers.
        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
            )
        return self.final_layer_norm(hidden_states)


class GPTNeoXForCausalLM(nn.Module):
    """GPT-NeoX model with a separate output head (``embed_out``).

    ``forward`` returns hidden states. Logits are produced separately by
    ``compute_logits``, and tokens are drawn by ``sample``.
    """

    # Checkpoint entries that have no matching parameter in this module and
    # are ignored while loading. They cover HF attention-mask buffers, rotary
    # inverse frequencies, and the cos/sin caches that models trained with
    # OpenRLHF may include in their checkpoints.
    _SKIPPED_WEIGHT_SUBSTRINGS = (
        "attention.bias",
        "attention.masked_bias",
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    )

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.gpt_neox = GPTNeoXModel(config, linear_method)
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the decoder stack and return its final hidden states."""
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to logits using the ``embed_out`` weight."""
        logits = self.logits_processor(self.embed_out.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy HuggingFace GPT-NeoX checkpoint tensors into this module.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Raises:
            KeyError: If a non-skipped checkpoint name does not match any
                parameter of this module.
        """
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if any(skip in name for skip in self._SKIPPED_WEIGHT_SUBSTRINGS):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                output_dim = getattr(param, "output_dim", None)
                if output_dim is not None:
                    loaded_weight = self._interleaved_to_grouped_qkv(
                        loaded_weight, output_dim,
                        self.config.num_attention_heads)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

    @staticmethod
    def _interleaved_to_grouped_qkv(loaded_weight: torch.Tensor,
                                    output_dim: int,
                                    num_heads: int) -> torch.Tensor:
        """Reorder a fused GPT-NeoX QKV tensor into grouped Q/K/V order.

        GPT-NeoX's fused QKV output dimension is laid out as
        (num_heads * 3 * head_size), i.e. q/k/v interleaved per head. The
        parallel QKV layer expects (3 * num_heads * head_size). This helper
        splits ``output_dim`` into (num_heads, 3, head_size), swaps the first
        two axes, and flattens back to the original shape.
        """
        original_shape = loaded_weight.shape
        split_shape = (original_shape[:output_dim] + (num_heads, 3, -1) +
                       original_shape[output_dim + 1:])
        return (loaded_weight.view(split_shape)
                .transpose(output_dim, output_dim + 1)
                .reshape(original_shape))