# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTNeoXConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTNeoXAttention(nn.Module):

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

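        # Fused QKV projection; the output columns are partitioned across
        # tensor-parallel ranks.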
        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            linear_method=linear_method,
        )

        scaling = self.head_size**-0.5
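        # GPT-NeoX applies rotary embeddings only to the first rotary_pct
        # fraction of each head's dimensions.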
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
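        # Paged attention over vLLM's block-based KV cache; the softmax scale
        # is 1/sqrt(head_size).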
        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
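        # Feed-forward block: hidden_size -> intermediate_size -> hidden_size.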
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            linear_method=linear_method,
        )
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.hidden_act, quant_config,
                              config.intermediate_size)

    def forward(self, hidden_states):
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
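        # use_parallel_residual selects between the parallel formulation
        # (attention and MLP read the same layer input) and the sequential
        # one; see forward() below.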
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config, linear_method)
        self.mlp = GPTNeoXMLP(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


class GPTNeoXModel(nn.Module):

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            GPTNeoXLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_in(input_ids)
        for i in range(len(self.layers)):
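            # The per-layer cache event (if any) is used to synchronize
            # KV-cache swap operations with this layer's attention.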
            cache_event = None if cache_events is None else cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class GPTNeoXForCausalLM(nn.Module):

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.gpt_neox = GPTNeoXModel(config, linear_method)
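        # GPT-NeoX does not tie input and output embeddings, so the LM head
        # keeps its own weight.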
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
        )
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      input_metadata, cache_events)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
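        # Compute logits against the LM head weight and sample the next
        # tokens according to the sampling metadata.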
        next_tokens = self.sampler(self.embed_out.weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
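            # Skip HF-only buffers: the causal-mask buffers and the rotary
            # inv_freq, which vLLM recomputes internally.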
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: In the HF checkpoint, the fused QKV weight is laid
                # out along output_dim as (num_heads * 3 * head_size),
                # whereas vLLM's QKVParallelLinear expects
                # (3 * num_heads * head_size), so the weight must be
                # permuted before loading.
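                # For example, with two heads the HF rows are ordered
                # [q0, k0, v0, q1, k1, v1], whereas vLLM expects
                # [q0, q1, k0, k1, v0, v1].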
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)