# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import GPTNeoXConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class GPTNeoXAttention(nn.Module):
    """Self-attention sub-layer of a GPT-NeoX decoder layer.

    Computes a fused QKV projection and applies rotary position embeddings
    to the queries and keys. It then runs vLLM's ``Attention`` layer with
    this layer's ``kv_cache`` and ``attn_metadata``, and projects the result
    back to ``hidden_size``. The attention heads are split evenly across
    the tensor-parallel group.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the QKV/output projections, rotary embedding and attention.

        Args:
            config: HuggingFace GPT-NeoX model configuration.
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Optional quantization config forwarded to the
                linear layers and to ``Attention``.

        Raises:
            ValueError: If ``config.num_attention_heads`` is not divisible by
                the tensor-parallel world size, or if the rotary dimension
                derived from ``config.rotary_pct`` is odd.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        # Configs that lack an ``attention_bias`` field get biased
        # QKV and output projections.
        self.bias = getattr(config, "attention_bias", True)

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Explicit exception instead of ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        if self.total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"Total number of attention heads ({self.total_num_heads}) "
                f"must be divisible by the tensor parallel size "
                f"({tensor_model_parallel_world_size}).")
        # Number of heads owned by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            quant_config=quant_config,
        )
        scaling = self.head_size**-0.5

        # The rotary embedding spans ``rotary_pct`` of each head's dimensions.
        rotary_dim = int(self.head_size * config.rotary_pct)
        if rotary_dim % 2 != 0:
            raise ValueError(
                f"Rotary dimension must be even, got {rotary_dim} "
                f"(head_size={self.head_size}, "
                f"rotary_pct={config.rotary_pct}).")
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            position_ids: Token positions consumed by the rotary embedding.
            hidden_states: Input activations for this layer.
            kv_cache: This layer's KV-cache tensor, passed to ``Attention``.
            attn_metadata: Batch metadata passed to ``Attention``.

        Returns:
            The attention output after the ``dense`` output projection.
        """
        qkv, _ = self.query_key_value(hidden_states)
        # The fused projection's output is split into equal q/k/v thirds.
        # GPTNeoXForCausalLM.load_weights reorders the checkpoint's per-head
        # interleaved QKV weights so that this split is correct.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):
    """Feed-forward sub-layer of a GPT-NeoX decoder layer.

    Applies ``hidden_size -> intermediate_size`` (column-parallel), then the
    ``config.hidden_act`` activation, then ``intermediate_size -> hidden_size``
    (row-parallel).
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the two projections and the activation.

        Args:
            config: HuggingFace GPT-NeoX model configuration.
            quant_config: Optional quantization config forwarded to the
                linear layers and the activation factory.
        """
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
        )
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act, quant_config,
                              config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up, apply the activation, and project back down."""
        # The parallel linear layers return (output, bias); the bias is
        # discarded here.
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX decoder layer: attention plus MLP with residuals.

    When ``config.use_parallel_residual`` is set, attention and MLP both read
    the same layer input. Otherwise the MLP reads the attention block's
    residual output.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the two layer norms, the attention and the MLP sub-layers.

        Args:
            config: HuggingFace GPT-NeoX model configuration.
            cache_config: KV-cache configuration for the attention sub-layer.
            quant_config: Optional quantization config for both sub-layers.
        """
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config, cache_config, quant_config)
        self.mlp = GPTNeoXMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the decoder layer and return the new hidden states.

        Args:
            position_ids: Token positions for the rotary embedding.
            hidden_states: Layer input activations.
            kv_cache: This layer's KV-cache tensor.
            attn_metadata: Batch metadata for the attention kernel.
        """
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


@support_torch_compile
class GPTNeoXModel(nn.Module):
    """GPT-NeoX decoder stack: embedding, decoder layers, final layer norm.

    Supports pipeline parallelism. Only the first rank embeds ``input_ids``.
    Every rank runs its own ``[start_layer, end_layer)`` slice of layers.
    Only the last rank applies the final layer norm.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the embedding, this rank's decoder layers and final norm.

        Args:
            config: HuggingFace GPT-NeoX model configuration.
            cache_config: KV-cache configuration for every attention layer.
            quant_config: Optional quantization config for every layer.
            prefix: Module-name prefix passed to ``make_layers``.
        """
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # make_layers returns this pipeline rank's layer range together with
        # the full ModuleList.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTNeoXLayer(config, cache_config, quant_config),
            prefix=f"{prefix}.layers",
        )
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    # NOTE: the parameter annotations below are read by
    # ``support_torch_compile``; keep them unchanged.
    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline rank's share of the decoder.

        Args:
            input_ids: Token ids; only read on the first pipeline rank.
            position_ids: Token positions for the rotary embeddings.
            kv_caches: One KV-cache tensor per layer held by this rank,
                indexed from ``start_layer``.
            attn_metadata: Batch metadata for the attention kernels.
            intermediate_tensors: Hidden states from the previous pipeline
                rank; only read on non-first ranks.

        Returns:
            On the last pipeline rank, the final-layer-normed hidden states.
            On every other rank, an ``IntermediateTensors`` holding
            ``"hidden_states"`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_in(input_ids)
        else:
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class GPTNeoXForCausalLM(nn.Module, SupportsPP):
    """GPT-NeoX causal language model for vLLM inference.

    Wraps ``GPTNeoXModel`` with an LM head (``embed_out``), a logits
    processor and a sampler. It also loads HuggingFace GPT-NeoX checkpoints.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the decoder, LM head, logits processor and sampler.

        Args:
            config: HuggingFace GPT-NeoX model configuration.
            cache_config: KV-cache configuration for the attention layers.
            quant_config: Optional quantization config for the model.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        # With tied embeddings the LM head shares the input embedding's weight.
        if self.config.tie_word_embeddings:
            self.embed_out.weight = self.gpt_neox.embed_in.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.gpt_neox.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder and return whatever ``GPTNeoXModel`` returns.

        That is hidden states on the last pipeline rank and
        ``IntermediateTensors`` on every other rank.
        """
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through ``embed_out``."""
        logits = self.logits_processor(self.embed_out, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Select next tokens from ``logits`` with the vLLM sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HuggingFace GPT-NeoX checkpoint tensors into this model.

        Args:
            weights: Iterable of ``(parameter_name, tensor)`` pairs.

        Raises:
            KeyError: If a checkpoint name is not skipped and matches no
                parameter of this model.
        """
        # Checkpoint entries that have no parameter in this model:
        # - "attention.bias" / "attention.masked_bias" and
        #   "rotary_emb.inv_freq": presumably buffers from the HF
        #   implementation (TODO confirm).
        # - "rotary_emb.cos_cached" / "rotary_emb.sin_cached": models trained
        #   with OpenRLHF may include these tensors in the checkpoint.
        ignored_substrings = (
            "attention.bias",
            "attention.masked_bias",
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if any(key in name for key in ignored_substrings):
                continue
            # Skip weights for layers this pipeline-parallel rank does not
            # hold (as reported by is_pp_missing_parameter).
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    # Split output_dim into (num_heads, 3, head_size), swap
                    # the first two axes, then flatten back.
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)