# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
19
from typing import List, Optional
20
21
22

import torch
from torch import nn
23
24
from transformers import GPTNeoXConfig

25
from vllm.attention import Attention, AttentionMetadata
Woosuk Kwon's avatar
Woosuk Kwon committed
26
from vllm.model_executor.layers.activation import get_act_fn
27
28
29
30
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
31
from vllm.model_executor.layers.logits_processor import LogitsProcessor
32
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
33
from vllm.model_executor.layers.sampler import Sampler
34
from vllm.model_executor.layers.vocab_parallel_embedding import (
35
    ParallelLMHead, VocabParallelEmbedding)
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.model_executor.parallel_utils.parallel_state import (
37
    get_tensor_model_parallel_world_size)
38
from vllm.model_executor.sampling_metadata import SamplingMetadata
39
40
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
41
from vllm.sequence import SamplerOutput
42
43
44
45


class GPTNeoXAttention(nn.Module):
    """Self-attention sub-block of a GPT-NeoX decoder layer.

    Attention heads are sharded across tensor-parallel ranks: this rank owns
    ``num_attention_heads // tensor_parallel_world_size`` heads. A rotary
    embedding with ``rotary_dim = int(head_size * config.rotary_pct)`` is
    built through ``get_rope`` and applied to the query and key tensors.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        # ``attention_bias`` may be absent from the config; in that case the
        # QKV and output projections are built with a bias term.
        self.bias = getattr(config, "attention_bias", True)

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be divisible "
            f"by the tensor parallel world size "
            f"({tensor_model_parallel_world_size})")
        # Number of attention heads handled by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # Fused Q/K/V input projection and the attention output projection.
        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            linear_method=linear_method,
        )
        # Standard 1/sqrt(head_size) softmax scaling.
        scaling = self.head_size**-0.5

        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0, (
            f"rotary_dim ({rotary_dim}) = head_size ({self.head_size}) * "
            f"rotary_pct ({config.rotary_pct}) must be even")
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply rotary self-attention to ``hidden_states``.

        Args:
            position_ids: Token positions consumed by the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: This layer's KV cache, forwarded to ``Attention``.
            attn_metadata: Batch attention metadata, forwarded to
                ``Attention``.

        Returns:
            The attention result after the ``dense`` output projection.
        """
        qkv, _ = self.query_key_value(hidden_states)
        # The fused projection output is split evenly into Q, K and V along
        # the last dim; load_weights permutes the HF checkpoint's
        # per-head-interleaved layout into this [Q | K | V] order.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):
    """Feed-forward sub-block of a GPT-NeoX decoder layer.

    Computes ``dense_4h_to_h(act(dense_h_to_4h(x)))``. The up-projection is
    column-parallel and the down-projection is row-parallel.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        # hidden_size -> intermediate_size projection.
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            linear_method=linear_method,
        )
        # intermediate_size -> hidden_size projection.
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            linear_method=linear_method,
        )
        # linear_method may be None or lack a quant_config; get_act_fn then
        # receives None.
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.hidden_act, quant_config,
                              config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation, then down-projection."""
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX decoder layer: attention plus MLP with residuals.

    The residual wiring depends on ``config.use_parallel_residual``:

    * parallel:   ``x = x + attn(ln1(x)) + mlp(ln2(x))``
    * sequential: ``x = x + attn(ln1(x)); x = x + mlp(ln2(x))``
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config, linear_method)
        self.mlp = GPTNeoXMLP(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run one decoder layer and return the updated hidden states.

        Args:
            position_ids: Token positions for the rotary embedding.
            hidden_states: Layer input (the residual stream).
            kv_cache: This layer's KV cache.
            attn_metadata: Batch attention metadata.
        """
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if self.use_parallel_residual:
            # x = x + attn(ln1(x)) + mlp(ln2(x)): the MLP reads the layer
            # input, not the attention output.
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


class GPTNeoXModel(nn.Module):
    """GPT-NeoX transformer trunk.

    The trunk is a token embedding, followed by ``num_hidden_layers`` decoder
    layers, followed by a final LayerNorm. It produces hidden states only;
    the LM head lives in ``GPTNeoXForCausalLM``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            GPTNeoXLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            position_ids: Token positions for the rotary embeddings.
            kv_caches: One KV cache per layer, indexed by layer number.
            attn_metadata: Batch attention metadata shared by all layers.

        Returns:
            Hidden states after the final LayerNorm.
        """
        hidden_states = self.embed_in(input_ids)
        for i, layer in enumerate(self.layers):
            # Layer i uses kv_caches[i].
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class GPTNeoXForCausalLM(nn.Module):
    """GPT-NeoX causal language model: trunk plus an untied LM head.

    ``forward`` returns hidden states. ``compute_logits`` projects them
    through ``embed_out``, and ``sample`` draws next tokens from the logits.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.gpt_neox = GPTNeoXModel(config, linear_method)
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the trunk's final hidden states (no LM head applied)."""
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits using ``embed_out``."""
        logits = self.logits_processor(self.embed_out.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace GPT-NeoX checkpoint weights into this model.

        Args:
            model_name_or_path: Checkpoint location, forwarded to
                ``hf_model_weights_iterator``.
            cache_dir: Download cache directory, forwarded as-is.
            load_format: Weight file format, forwarded as-is.
            revision: Checkpoint revision, forwarded as-is.

        Raises:
            KeyError: If the checkpoint contains a (non-skipped) tensor name
                that does not match a parameter of this model.
        """
        # Checkpoint entries with these substrings are not loaded. They are
        # presumably non-parameter buffers (masks / precomputed rotary
        # frequencies) in the HF implementation.
        skipped_substrings = ("attention.bias", "attention.masked_bias",
                              "rotary_emb.inv_freq")
        # Loop-invariant: total (unsharded) head count used for the QKV
        # permutation below.
        num_heads = self.config.num_attention_heads
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if any(sub in name for sub in skipped_substrings):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    # Split output_dim into (num_heads, 3, head_size), swap
                    # the first two axes, then flatten back to the original
                    # shape to obtain the [Q | K | V] ordering.
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)