# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTNeoXConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]
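# A KVCache is the (key cache, value cache) tensor pair for a single attention
# layer; the caches are allocated outside the model and passed in per layer.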


class GPTNeoXAttention(nn.Module):

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
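        # Attention heads are partitioned evenly across tensor-parallel ranks.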
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            linear_method=linear_method,
        )

        scaling = self.head_size**-0.5
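        # GPT-NeoX applies rotary embeddings to only the first `rotary_pct`
        # fraction of each head's dimensions.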
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
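        # The fused QKV weight is re-ordered at load time (see load_weights)
        # so that a plain chunk along the last dim yields q, k, and v.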
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        k_cache, v_cache = kv_cache
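        # PagedAttention stores the new keys/values into the paged cache and
        # attends over both cached and freshly computed tokens.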
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            linear_method=linear_method,
        )
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.hidden_act, quant_config,
                              config.intermediate_size)

    def forward(self, hidden_states):
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config, linear_method)
        self.mlp = GPTNeoXMLP(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


class GPTNeoXModel(nn.Module):

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            GPTNeoXLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_in(input_ids)
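        # Run each transformer layer with its own KV cache and, if provided,
        # its corresponding CUDA cache event.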
        for i in range(len(self.layers)):
            cache_event = None if cache_events is None else cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class GPTNeoXForCausalLM(nn.Module):

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.gpt_neox = GPTNeoXModel(config, linear_method)
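        # GPT-NeoX keeps an untied output embedding, so embed_out is a
        # separate vocab-parallel LM head rather than a reuse of embed_in.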
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
        )
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      input_metadata, cache_events)
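        # Project to vocabulary logits with the LM head weight and sample the
        # next tokens.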
        next_tokens = self.sampler(self.embed_out.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
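            # These HF-only buffers are not needed here; causal masking and
            # rotary inverse frequencies are handled by vLLM's own layers.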
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
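                # e.g. for a 2-D weight with output_dim == 0:
                #   view      -> (num_heads, 3, head_size, hidden_size)
                #   transpose -> (3, num_heads, head_size, hidden_size)
                #   reshape back to the original shape, now ordered as
                #   [all q heads | all k heads | all v heads].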
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)