# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
18
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
19
from typing import Iterable, List, Optional, Tuple
20
21
22

import torch
from torch import nn
23
24
from transformers import GPTNeoXConfig

25
from vllm.attention import Attention, AttentionMetadata
26
from vllm.config import CacheConfig
27
from vllm.distributed import get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
28
from vllm.model_executor.layers.activation import get_act_fn
29
30
31
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
32
from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
34
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
35
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.model_executor.layers.sampler import Sampler
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
38
    ParallelLMHead, VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import SamplerOutput
42
43
44
45


class GPTNeoXAttention(nn.Module):
    """Self-attention sub-layer of a GPT-NeoX block.

    The forward pass runs a fused QKV projection, applies rotary position
    embeddings to the query and key, calls the attention backend with the
    KV cache, and finishes with the output (``dense``) projection.

    Args:
        config: Model configuration; supplies head count, hidden size,
            ``rotary_pct`` and, when present, ``attention_bias``,
            ``rope_theta`` and ``max_position_embeddings``.
        cache_config: Forwarded to the attention backend.
        quant_config: Forwarded to both linear projections.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        # Configs that do not define `attention_bias` are treated as biased.
        self.bias = getattr(config, "attention_bias", True)

        # The heads must divide evenly by the tensor-parallel world size;
        # `num_heads` is the per-rank share.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            quant_config=quant_config,
        )

        softmax_scale = self.head_size**-0.5

        # Rotary width is the `rotary_pct` fraction of the head size and is
        # required to be even.
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        rope_base = getattr(config, "rope_theta", 10000)
        max_positions = getattr(config, "max_position_embeddings", 8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_positions,
            base=rope_base,
        )

        self.attn = Attention(self.num_heads,
                              self.head_size,
                              softmax_scale,
                              cache_config=cache_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute the attention output for ``hidden_states``.

        Args:
            position_ids: Positions handed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Cache tensor handed to the attention backend.
            attn_metadata: Metadata handed to the attention backend.

        Returns:
            The result of the ``dense`` output projection.
        """
        fused_qkv, _ = self.query_key_value(hidden_states)
        # The fused projection is split into three equal chunks on the last
        # dimension, taken in query / key / value order.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.dense(context)
        return projected


class GPTNeoXMLP(nn.Module):
    """Feed-forward sub-layer of a GPT-NeoX block.

    Projects from ``hidden_size`` to ``intermediate_size``, applies the
    activation named by ``config.hidden_act``, and projects back to
    ``hidden_size``.

    Args:
        config: Model configuration; supplies the two sizes and the
            activation name.
        quant_config: Forwarded to both linear layers and to the
            activation factory.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_dim = config.hidden_size
        inner_dim = config.intermediate_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_dim,
            inner_dim,
            quant_config=quant_config,
        )
        self.dense_4h_to_h = RowParallelLinear(
            inner_dim,
            hidden_dim,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act, quant_config, inner_dim)

    def forward(self, hidden_states):
        """Apply up-projection, activation, then down-projection."""
        # Both linear layers return a pair; only the first element is used.
        widened, _ = self.dense_h_to_4h(hidden_states)
        widened = self.act(widened)
        narrowed, _ = self.dense_4h_to_h(widened)
        return narrowed


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX transformer block: attention plus MLP with residuals.

    ``config.use_parallel_residual`` selects how the two sub-layers are
    combined with the residual stream (see ``forward``).

    Args:
        config: Model configuration.
        cache_config: Forwarded to the attention sub-layer.
        quant_config: Forwarded to the attention and MLP sub-layers.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        norm_eps = config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=norm_eps)
        self.attention = GPTNeoXAttention(config, cache_config, quant_config)
        self.mlp = GPTNeoXMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the block on ``hidden_states`` and return the new residual.

        Args:
            position_ids: Passed through to the attention sub-layer.
            hidden_states: Residual-stream input to the block.
            kv_cache: Passed through to the attention sub-layer.
            attn_metadata: Passed through to the attention sub-layer.

        Returns:
            The updated residual-stream tensor.
        """
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if not self.use_parallel_residual:
            # Sequential form:
            #   x = x + attn(ln1(x))
            #   x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
            return mlp_output + attn_output

        # Parallel form: both sub-layers read the same block input.
        #   x = x + attn(ln1(x)) + mlp(ln2(x))
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        return mlp_output + attn_output + hidden_states


class GPTNeoXModel(nn.Module):
    """GPT-NeoX backbone: token embedding, layer stack, final LayerNorm.

    Args:
        config: Model configuration; supplies vocabulary size, hidden size,
            layer count and LayerNorm epsilon.
        cache_config: Forwarded to every layer.
        quant_config: Forwarded to every layer.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        blocks = [
            GPTNeoXLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every layer.

        Args:
            input_ids: Token ids passed to the input embedding.
            position_ids: Passed unchanged to every layer.
            kv_caches: One cache entry per layer, indexed by layer position.
            attn_metadata: Passed unchanged to every layer.

        Returns:
            The output of the final LayerNorm.
        """
        hidden_states = self.embed_in(input_ids)
        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
            )
        return self.final_layer_norm(hidden_states)


class GPTNeoXForCausalLM(nn.Module):
    """GPT-NeoX backbone with an output head, logits processor and sampler.

    ``forward`` returns only the backbone's hidden states. Logits come from
    ``compute_logits`` and token selection from ``sample``.

    Args:
        config: Model configuration.
        cache_config: Forwarded to the backbone.
        quant_config: Stored on the instance and forwarded to the backbone.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the backbone's hidden states for the given inputs."""
        return self.gpt_neox(input_ids, positions, kv_caches, attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Run the logits processor with the ``embed_out`` weight."""
        return self.logits_processor(self.embed_out.weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Return the sampler's output for ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs; each kept name must match a
                key of ``self.named_parameters()``.

        Raises:
            KeyError: If a name that is not skipped has no matching
                parameter.
        """
        # Checkpoint entries whose name contains any of these substrings are
        # skipped rather than loaded.
        ignored_fragments = (
            "attention.bias",
            "attention.masked_bias",
            "rotary_emb.inv_freq",
            # Models trained using OpenRLHF may include these tensors in
            # the checkpoint. Skip them.
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if any(fragment in name for fragment in ignored_fragments):
                continue

            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the required shape is
                # (3 * num_heads * head_size), so the tensor is regrouped
                # along that dimension before loading.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    flat_shape = loaded_weight.shape
                    split_shape = (flat_shape[:output_dim] +
                                   (num_heads, 3, -1) +
                                   flat_shape[output_dim + 1:])
                    # Split into (num_heads, 3, rest), swap those two axes,
                    # then flatten back to the original shape.
                    loaded_weight = loaded_weight.view(split_shape).transpose(
                        output_dim, output_dim + 1).reshape(flat_shape)

            # Parameters may carry their own loader; otherwise use the
            # default one.
            load_fn = getattr(param, "weight_loader", default_weight_loader)
            load_fn(param, loaded_weight)