# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTNeoXConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput


class GPTNeoXAttention(nn.Module):
    """Self-attention sub-layer of a GPT-NeoX block.

    The module owns a fused query/key/value projection
    (``query_key_value``), an output projection (``dense``), a rotary
    embedding (``rotary_emb``) and the attention operator (``attn``).
    The attention heads are divided evenly between the tensor-parallel
    ranks; ``num_heads`` is the per-rank share.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections, rotary embedding and attention operator.

        Args:
            config: Model configuration. Reads ``num_attention_heads``,
                ``hidden_size`` and ``rotary_pct``, plus the optional
                ``attention_bias`` (default ``True``), ``rope_theta``
                (default ``10000``) and ``max_position_embeddings``
                (default ``8192``).
            quant_config: Optional quantization configuration forwarded to
                both linear projections.
        """
        super().__init__()
        total_heads = config.num_attention_heads
        width = config.hidden_size
        self.total_num_heads = total_heads
        self.hidden_size = width
        self.head_size = width // total_heads
        self.bias = getattr(config, "attention_bias", True)

        # The heads must split evenly across the tensor-parallel group.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            width,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            width,
            width,
            bias=self.bias,
            quant_config=quant_config,
        )

        # Scale factor handed to the attention operator: 1 / sqrt(head_size).
        softmax_scale = self.head_size**-0.5

        # Only ``rotary_pct`` of each head is passed on as the rotary
        # dimension; it has to be even.
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0

        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=getattr(config, "max_position_embeddings", 8192),
            base=getattr(config, "rope_theta", 10000),
        )
        self.attn = Attention(self.num_heads, self.head_size, softmax_scale)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            position_ids: Positions passed to the rotary embedding.
            hidden_states: Input activations of the sub-layer.
            kv_cache: Key/value cache handed to the attention operator.
            attn_metadata: Attention metadata handed to the attention
                operator.

        Returns:
            The output of the ``dense`` projection.
        """
        fused_qkv, _ = self.query_key_value(hidden_states)
        # The fused projection is cut into three equal parts on the last dim.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.dense(context)
        return projected


class GPTNeoXMLP(nn.Module):
    """Feed-forward sub-layer of a GPT-NeoX block.

    Applies ``dense_h_to_4h`` (``hidden_size`` -> ``intermediate_size``),
    the configured activation, then ``dense_4h_to_h``
    (``intermediate_size`` -> ``hidden_size``).
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the two projections and the activation.

        Args:
            config: Model configuration. Reads ``hidden_size``,
                ``intermediate_size`` and ``hidden_act``.
            quant_config: Optional quantization configuration, forwarded to
                both projections and to the activation factory.
        """
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
        )
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        # Pass the quantization config through as-is. Looking up a
        # ``quant_config`` attribute on the config object itself would
        # resolve to ``None`` and hide the quantization settings from the
        # activation factory.
        self.act = get_act_fn(config.hidden_act, quant_config,
                              config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation and down-projection in order.

        Args:
            hidden_states: Input activations of the sub-layer.

        Returns:
            The output of ``dense_4h_to_h``.
        """
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX transformer block: attention and MLP with residuals.

    ``config.use_parallel_residual`` selects how the two sub-layers are
    combined with the residual stream (see ``forward``).
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the two layer norms, the attention and the MLP.

        Args:
            config: Model configuration. Reads ``use_parallel_residual``,
                ``hidden_size`` and ``layer_norm_eps`` here; the sub-layers
                read further fields.
            quant_config: Optional quantization configuration forwarded to
                the attention and MLP sub-layers.
        """
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual

        width, eps = config.hidden_size, config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)

        self.attention = GPTNeoXAttention(config, quant_config)
        self.mlp = GPTNeoXMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the block to ``hidden_states``.

        Args:
            position_ids: Positions forwarded to the attention sub-layer.
            hidden_states: Residual-stream input of the block.
            kv_cache: Key/value cache forwarded to the attention sub-layer.
            attn_metadata: Attention metadata forwarded to the attention
                sub-layer.

        Returns:
            The updated residual stream.
        """
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if not self.use_parallel_residual:
            # Sequential form:
            #   x = x + attn(ln1(x))
            #   x = x + mlp(ln2(x))
            after_attn = attn_output + hidden_states
            ffn_output = self.mlp(self.post_attention_layernorm(after_attn))
            return ffn_output + after_attn

        # Parallel form:
        #   x = x + attn(ln1(x)) + mlp(ln2(x))
        ffn_output = self.mlp(self.post_attention_layernorm(hidden_states))
        return ffn_output + attn_output + hidden_states


class GPTNeoXModel(nn.Module):
    """GPT-NeoX backbone: token embedding, block stack and final layer norm."""

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the embedding, ``num_hidden_layers`` blocks and final norm.

        Args:
            config: Model configuration. Reads ``vocab_size``,
                ``hidden_size``, ``num_hidden_layers`` and
                ``layer_norm_eps`` here; the blocks read further fields.
            quant_config: Optional quantization configuration forwarded to
                every block.
        """
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        blocks = [
            GPTNeoXLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every block.

        Args:
            input_ids: Token ids fed to the input embedding.
            position_ids: Positions forwarded to every block.
            kv_caches: One cache per block, indexed by block position.
            attn_metadata: Attention metadata forwarded to every block.

        Returns:
            The hidden states after the final layer norm.
        """
        hidden_states = self.embed_in(input_ids)
        # Block ``idx`` uses the cache stored at the same index.
        for idx, block in enumerate(self.layers):
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[idx],
                attn_metadata,
            )
        return self.final_layer_norm(hidden_states)


class GPTNeoXForCausalLM(nn.Module):
    """GPT-NeoX backbone with an output embedding, logits step and sampler."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the backbone, output embedding, logits processor and sampler.

        Args:
            config: Model configuration. Reads ``vocab_size`` and
                ``hidden_size`` here; ``num_attention_heads`` is read when
                weights are loaded.
            quant_config: Optional quantization configuration forwarded to
                the backbone.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.gpt_neox = GPTNeoXModel(config, quant_config)
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the backbone and return its final hidden states.

        Logits are not computed here; see ``compute_logits``.
        """
        return self.gpt_neox(input_ids, positions, kv_caches, attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits using the ``embed_out`` weight."""
        return self.logits_processor(self.embed_out.weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pass ``logits`` to the sampler and return its result."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs. Every name that is not
                skipped must match a parameter name of this module.

        Raises:
            KeyError: If a non-skipped name has no matching parameter.
        """
        # Checkpoint entries with no parameter counterpart here. The cached
        # cos/sin rotary tables may be present in checkpoints of models
        # trained using OpenRLHF.
        skipped_markers = (
            "attention.bias",
            "attention.masked_bias",
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if any(marker in name for marker in skipped_markers):
                continue

            param = params_dict[name]

            # Only fused QKV tensors whose parameter declares an output
            # dimension need their layout converted.
            shard_dim = (getattr(param, "output_dim", None)
                         if "query_key_value" in name else None)
            if shard_dim is not None:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the required shape is
                # (3 * num_heads * head_size). Split that dimension into
                # (num_heads, 3, head_size), swap the first two factors and
                # flatten back to the original shape.
                total_heads = self.config.num_attention_heads
                original_shape = loaded_weight.shape
                per_head_shape = (original_shape[:shard_dim] +
                                  (total_heads, 3, -1) +
                                  original_shape[shard_dim + 1:])
                loaded_weight = (loaded_weight.view(per_head_shape).transpose(
                    shard_dim, shard_dim + 1).reshape(original_shape))

            # Parameters may carry their own loader; otherwise use the
            # default one.
            load_fn = getattr(param, "weight_loader", default_weight_loader)
            load_fn(param, loaded_weight)