gpt_neox.py 11.5 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
18
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
19
from typing import Iterable, List, Optional, Tuple
20
21
22

import torch
from torch import nn
23
24
from transformers import GPTNeoXConfig

25
from vllm.attention import Attention, AttentionMetadata
26
from vllm.config import CacheConfig
27
from vllm.distributed import get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
28
from vllm.model_executor.layers.activation import get_act_fn
29
30
31
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
32
from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
34
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
35
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.model_executor.layers.sampler import Sampler
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
38
    ParallelLMHead, VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import IntermediateTensors, SamplerOutput
42
43
44
45


class GPTNeoXAttention(nn.Module):
    """Self-attention sub-block of a GPT-NeoX decoder layer.

    Attention heads are split across tensor-parallel ranks; each rank holds
    ``total_num_heads // world_size`` heads. Q, K and V come from a single
    fused ``QKVParallelLinear`` projection, rotary embeddings (``get_rope``)
    are applied to Q and K, and the result is projected back to
    ``hidden_size`` by a ``RowParallelLinear``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the attention projections, rotary embedding and kernel.

        Args:
            config: HuggingFace GPT-NeoX configuration.
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Optional quantization config forwarded to the
                linear layers and ``Attention``.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        # Projections carry a bias unless the config sets attention_bias.
        self.bias = getattr(config, "attention_bias", True)

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Every tensor-parallel rank must own a whole number of heads.
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            quant_config=quant_config,
        )
        scaling = self.head_size**-0.5

        # rotary_pct selects the fraction of each head's channels handed to
        # the rotary embedding; the resulting dimension must be even.
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            position_ids: Token positions used by the rotary embedding.
            hidden_states: Input activations.
            kv_cache: This layer's KV-cache tensor, passed to ``Attention``.
            attn_metadata: Batch metadata consumed by ``Attention``.

        Returns:
            The attention output after the ``dense`` projection.
        """
        qkv, _ = self.query_key_value(hidden_states)
        # The fused projection output is laid out as [q | k | v] along the
        # last dim (the checkpoint layout is converted to this at load time).
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):
    """Feed-forward sub-block of a GPT-NeoX decoder layer.

    Computes ``dense_4h_to_h(act(dense_h_to_4h(x)))``. The up-projection is
    column-parallel and the down-projection row-parallel across
    tensor-parallel ranks.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the two projections and the activation.

        Args:
            config: HuggingFace GPT-NeoX configuration (``hidden_size``,
                ``intermediate_size``, ``hidden_act`` are read).
            quant_config: Optional quantization config for the linear layers
                and the activation.
        """
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
        )
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act, quant_config,
                              config.intermediate_size)

    def forward(self, hidden_states):
        """Project up, activate, and project back down to ``hidden_size``."""
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX decoder layer: attention plus MLP with residuals.

    ``config.use_parallel_residual`` selects between the parallel form
    ``x + attn(ln1(x)) + mlp(ln2(x))`` and the sequential form
    ``x = x + attn(ln1(x)); x = x + mlp(ln2(x))``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build both layer norms, the attention block and the MLP.

        Args:
            config: HuggingFace GPT-NeoX configuration.
            cache_config: KV-cache configuration for the attention block.
            quant_config: Optional quantization config for sub-blocks.
        """
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config, cache_config, quant_config)
        self.mlp = GPTNeoXMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layer and return the updated hidden states.

        Args:
            position_ids: Token positions for the rotary embedding.
            hidden_states: Layer input activations.
            kv_cache: This layer's KV-cache tensor.
            attn_metadata: Batch attention metadata.
        """
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if self.use_parallel_residual:
            # x = x + attn(ln1(x)) + mlp(ln2(x)): both branches read the
            # original layer input.
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


class GPTNeoXModel(nn.Module):
    """GPT-NeoX transformer trunk.

    Consists of a vocab-parallel token embedding, ``num_hidden_layers``
    decoder layers, and a final LayerNorm. No LM head.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the embedding, decoder stack and final layer norm.

        Args:
            config: HuggingFace GPT-NeoX configuration.
            cache_config: KV-cache configuration for every layer.
            quant_config: Optional quantization config for every layer.
        """
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            GPTNeoXLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            position_ids: Token positions for the rotary embeddings.
            kv_caches: One KV-cache tensor per layer, indexed by layer.
            attn_metadata: Batch attention metadata shared by all layers.

        Returns:
            Final-layer-normed hidden states.
        """
        hidden_states = self.embed_in(input_ids)
        # Index kv_caches by layer position so a short list still fails
        # loudly instead of silently skipping layers.
        for i, layer in enumerate(self.layers):
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class GPTNeoXForCausalLM(nn.Module):
    """GPT-NeoX trunk plus language-modeling head, logits processor, sampler."""

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the trunk, the LM head and the sampling utilities.

        Args:
            config: HuggingFace GPT-NeoX configuration.
            cache_config: KV-cache configuration for the trunk.
            quant_config: Optional quantization config.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # The LM head reuses the input-embedding Parameter object.
            self.embed_out.weight = self.gpt_neox.embed_in.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the trunk's final hidden states.

        ``intermediate_tensors`` is accepted for interface compatibility but
        is not used by this model.
        """
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``embed_out``."""
        logits = self.logits_processor(self.embed_out, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs in HuggingFace naming.

        Raises:
            KeyError: If a non-skipped checkpoint name has no matching
                parameter.
        """
        params_dict = dict(self.named_parameters())
        # Checkpoint entries that have no parameter in this module:
        # "attention.bias"/"attention.masked_bias" (presumably HF attention
        # buffers), and rotary caches; the cos/sin caches appear in
        # checkpoints from models trained with OpenRLHF.
        skipped_substrings = (
            "attention.bias",
            "attention.masked_bias",
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        for name, loaded_weight in weights:
            if any(key in name for key in skipped_substrings):
                continue
            if self.config.tie_word_embeddings and "embed_out.weight" in name:
                # embed_out shares embed_in's Parameter, and
                # named_parameters() lists a shared Parameter only once
                # (as gpt_neox.embed_in.weight), so there is no
                # "embed_out.weight" entry to load into.
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    # Split output_dim into (num_heads, 3, head_size), swap
                    # the first two axes, then flatten back.
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)