# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import List, Optional

import torch
from torch import nn
from transformers import GPTNeoXConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput


class GPTNeoXAttention(nn.Module):
    """GPT-NeoX self-attention with rotary position embeddings.

    The fused ``query_key_value`` projection is a ``QKVParallelLinear`` and
    the output ``dense`` projection is a ``RowParallelLinear``; each
    tensor-parallel rank owns ``total_num_heads // world_size`` heads.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the attention projections, rotary embedding and kernel.

        Args:
            config: HuggingFace GPT-NeoX configuration.
            linear_method: Optional (e.g. quantized) linear implementation
                forwarded to the parallel linear layers.

        Raises:
            ValueError: If the number of attention heads is not divisible
                by the tensor-parallel world size, or if the rotary
                dimension derived from ``rotary_pct`` is odd.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        # Configs without an `attention_bias` field use biased projections.
        self.bias = getattr(config, "attention_bias", True)

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Heads are split evenly across tensor-parallel ranks. This is an
        # explicit raise (not an assert) so it still fires under `python -O`.
        if self.total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_attention_heads ({self.total_num_heads}) must be "
                f"divisible by the tensor parallel size "
                f"({tensor_model_parallel_world_size}).")
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            linear_method=linear_method,
        )
        scaling = self.head_size**-0.5

        # Only the first `rotary_pct` fraction of each head is rotated; the
        # rotation pairs up dimensions, so the count must be even.
        rotary_dim = int(self.head_size * config.rotary_pct)
        if rotary_dim % 2 != 0:
            raise ValueError(
                f"rotary dimension ({rotary_dim}) derived from head_size "
                f"({self.head_size}) and rotary_pct ({config.rotary_pct}) "
                f"must be even.")
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            position_ids: Token positions used by the rotary embedding.
            hidden_states: Layer input activations.
            kv_cache: This layer's KV cache, passed to the attention kernel.
            attn_metadata: Batch attention metadata for the kernel.

        Returns:
            The attention output after the ``dense`` projection.
        """
        qkv, _ = self.query_key_value(hidden_states)
        # The fused projection output is laid out as [q | k | v] along the
        # last dim; load_weights reorders the HF per-head-interleaved
        # checkpoint weights into this layout.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):
    """GPT-NeoX feed-forward block: hidden -> intermediate -> hidden.

    The up projection is column-parallel and the down projection is
    row-parallel, with ``config.hidden_act`` applied in between.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the two projections and the activation.

        Args:
            config: HuggingFace GPT-NeoX configuration.
            linear_method: Optional (e.g. quantized) linear implementation
                forwarded to the parallel linear layers.
        """
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            linear_method=linear_method,
        )
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            linear_method=linear_method,
        )
        # Forward the linear method's quantization config (None when there
        # is no linear method or it has no `quant_config`) to the
        # activation factory.
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.hidden_act, quant_config,
                              config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation and down-projection."""
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX decoder layer: attention plus MLP with pre-LayerNorm.

    Supports both the parallel-residual form (attention and MLP read the
    same input) and the sequential form, selected by
    ``config.use_parallel_residual``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the layer norms, attention and MLP sub-modules.

        Args:
            config: HuggingFace GPT-NeoX configuration.
            linear_method: Optional (e.g. quantized) linear implementation
                forwarded to the attention and MLP projections.
        """
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config, linear_method)
        self.mlp = GPTNeoXMLP(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the decoder layer and return its output activations.

        Args:
            position_ids: Token positions for the rotary embedding.
            hidden_states: Layer input activations.
            kv_cache: This layer's KV cache.
            attn_metadata: Batch attention metadata.
        """
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


class GPTNeoXModel(nn.Module):
    """GPT-NeoX transformer trunk: embedding, decoder stack, final norm."""

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the embedding, ``num_hidden_layers`` decoder layers and
        the final LayerNorm.

        Args:
            config: HuggingFace GPT-NeoX configuration.
            linear_method: Optional (e.g. quantized) linear implementation
                forwarded to every decoder layer.
        """
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            GPTNeoXLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer, and normalize.

        Args:
            input_ids: Token ids to embed.
            position_ids: Token positions for the rotary embeddings.
            kv_caches: One KV cache per decoder layer, indexed by layer.
            attn_metadata: Batch attention metadata shared by all layers.

        Returns:
            Final hidden states after ``final_layer_norm``.
        """
        hidden_states = self.embed_in(input_ids)
        # Index kv_caches directly (rather than zip) so a cache list shorter
        # than the layer stack fails loudly instead of skipping layers.
        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
            )
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class GPTNeoXForCausalLM(nn.Module):
    """GPT-NeoX trunk plus an untied LM head, logits processor and sampler."""

    # Substrings of checkpoint tensor names that have no matching parameter
    # in this model and are skipped during loading:
    # - "attention.bias", "attention.masked_bias", "rotary_emb.inv_freq":
    #   presumably HF non-parameter buffers (TODO confirm against the HF
    #   GPT-NeoX implementation).
    # - "rotary_emb.cos_cached", "rotary_emb.sin_cached": models trained
    #   using OpenRLHF may include these tensors in the checkpoint.
    _SKIPPED_WEIGHT_SUBSTRINGS = (
        "attention.bias",
        "attention.masked_bias",
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    )

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the model trunk, the output head and sampling helpers.

        Args:
            config: HuggingFace GPT-NeoX configuration.
            linear_method: Optional (e.g. quantized) linear implementation
                forwarded to the trunk.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.gpt_neox = GPTNeoXModel(config, linear_method)
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the trunk and return final hidden states (not logits).

        Logits are produced separately by :meth:`compute_logits`.
        """
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits using ``embed_out``."""
        logits = self.logits_processor(self.embed_out.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    @staticmethod
    def _convert_fused_qkv_layout(loaded_weight: torch.Tensor,
                                  output_dim: int,
                                  num_heads: int) -> torch.Tensor:
        """Reorder a fused QKV tensor from per-head to per-projection order.

        GPT-NeoX stores the fused QKV output dimension as
        (num_heads * 3 * head_size), i.e. [q_h, k_h, v_h] for each head h.
        The QKV layer here expects (3 * num_heads * head_size), i.e. all q
        heads, then all k heads, then all v heads. The shape of the returned
        tensor equals the input shape.

        Args:
            loaded_weight: Full (unsharded) checkpoint tensor.
            output_dim: Index of the fused output dimension.
            num_heads: Total number of attention heads.
        """
        original_shape = loaded_weight.shape
        # Split the fused dim into (num_heads, 3, head_size) ...
        loaded_weight = loaded_weight.view(original_shape[:output_dim] +
                                           (num_heads, 3, -1) +
                                           original_shape[output_dim + 1:])
        # ... swap to (3, num_heads, head_size) ...
        loaded_weight = loaded_weight.transpose(output_dim, output_dim + 1)
        # ... and flatten back (reshape copies the non-contiguous result).
        return loaded_weight.reshape(original_shape)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace GPT-NeoX checkpoint weights into this model.

        Tensors matching ``_SKIPPED_WEIGHT_SUBSTRINGS`` are ignored; fused
        QKV tensors are reordered before being handed to each parameter's
        ``weight_loader`` (or ``default_weight_loader``).

        Raises:
            KeyError: If the checkpoint contains a tensor name that is
                neither skipped nor a parameter of this model.
        """
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if any(skipped in name
                   for skipped in self._SKIPPED_WEIGHT_SUBSTRINGS):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV output_dim has the shape of
                # (num_heads * 3 * head_size), while the required shape is
                # (3 * num_heads * head_size); convert the layout.
                output_dim = getattr(param, "output_dim", None)
                if output_dim is not None:
                    loaded_weight = self._convert_fused_qkv_layout(
                        loaded_weight, output_dim,
                        self.config.num_attention_heads)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)