# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import GPTNeoXConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class GPTNeoXAttention(nn.Module):
    """GPT-NeoX self-attention with tensor-parallel projections.

    A fused ``query_key_value`` projection produces q, k and v for the heads
    owned by this tensor-parallel rank. Rotary position embeddings (with
    ``rotary_dim = head_size * config.rotary_pct``) are applied to q and k.
    The attention output is projected back to ``hidden_size`` by ``dense``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        # Fall back to biased projections when the config does not define
        # ``attention_bias``.
        self.bias = getattr(config, "attention_bias", True)

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Heads are split evenly across tensor-parallel ranks.
        assert self.total_num_heads % tensor_model_parallel_world_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be divisible "
            f"by the tensor parallel size ({tensor_model_parallel_world_size})")
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            quant_config=quant_config,
        )
        scaling = self.head_size**-0.5
        # Only a fraction (``rotary_pct``) of each head's channels is rotated.
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0, (
            f"rotary_dim ({rotary_dim}) must be even")
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention for this rank's heads.

        Args:
            position_ids: Token positions used by the rotary embedding.
            hidden_states: Input activations.
            kv_cache: This layer's KV cache tensor.
            attn_metadata: Batch-level attention metadata.

        Returns:
            The output of the ``dense`` projection.
        """
        qkv, _ = self.query_key_value(hidden_states)
        # A plain 3-way split is valid because ``load_weights`` reorders the
        # checkpoint's per-head interleaved QKV rows into q | k | v order.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):
    """GPT-NeoX feed-forward block.

    The block computes ``dense_4h_to_h(act(dense_h_to_4h(x)))``. The up
    projection is column-parallel and the down projection is row-parallel.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
        )
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation and down-projection."""
        # The parallel linear layers return ``(output, bias)``; only the
        # output is used here.
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX transformer block: attention plus MLP with residuals.

    ``config.use_parallel_residual`` selects between the parallel form
    ``x + attn(ln1(x)) + mlp(ln2(x))`` and the sequential form
    ``x' = x + attn(ln1(x)); x' + mlp(ln2(x'))``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config, cache_config, quant_config)
        self.mlp = GPTNeoXMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the block to ``hidden_states`` and return the new states."""
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            # Both branches read the same block input ``hidden_states``.
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


@support_torch_compile
class GPTNeoXModel(nn.Module):
    """GPT-NeoX decoder stack with pipeline-parallel support.

    On the first pipeline rank, token ids are embedded with ``embed_in``.
    Later ranks instead receive ``hidden_states`` through
    ``intermediate_tensors``. Each rank runs only its own slice of layers,
    ``[start_layer, end_layer)``. The last rank applies ``final_layer_norm``;
    every other rank returns its output wrapped in ``IntermediateTensors``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # ``make_layers`` builds only this rank's layers and reports the
        # global index range they cover.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTNeoXLayer(config, cache_config, quant_config),
            prefix=f"{prefix}.layers",
        )
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's portion of the decoder.

        Args:
            input_ids: Token ids (used only on the first pipeline rank).
            position_ids: Token positions for rotary embeddings.
            kv_caches: KV caches for this rank's layers. They are indexed
                relative to ``start_layer``.
            attn_metadata: Batch-level attention metadata.
            intermediate_tensors: Output of the previous pipeline rank.
                Required on every rank except the first.

        Returns:
            On the last rank, the final-layer-normed hidden states.
            Otherwise, ``IntermediateTensors`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_in(input_ids)
        else:
            assert intermediate_tensors is not None, (
                "intermediate_tensors must be provided on non-first "
                "pipeline-parallel ranks")
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class GPTNeoXForCausalLM(nn.Module, SupportsPP):
    """GPT-NeoX causal language model: decoder stack plus LM head.

    When ``config.tie_word_embeddings`` is set, the output head
    ``embed_out`` shares its weight Parameter with the input embedding
    ``gpt_neox.embed_in``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            self.embed_out.weight = self.gpt_neox.embed_in.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.gpt_neox.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack.

        Returns hidden states on the last pipeline rank, otherwise
        ``IntermediateTensors``.
        """
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``embed_out``."""
        logits = self.logits_processor(self.embed_out, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Select next tokens from ``logits`` using the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HuggingFace GPT-NeoX checkpoint tensors into this model.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from a checkpoint.
        """
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                # Non-learned checkpoint tensors that this implementation
                # does not load.
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using OpenRLHF may include
                # these tensors in the checkpoint. Skip them.
                continue
            if self.config.tie_word_embeddings and "embed_out.weight" in name:
                # With tied embeddings, ``embed_out.weight`` is the same
                # Parameter as ``gpt_neox.embed_in.weight``.
                # named_parameters() de-duplicates it, so it has no entry in
                # params_dict and the input embedding already supplies it.
                continue
            if is_pp_missing_parameter(name, self):
                # Parameter belongs to a layer owned by another PP rank.
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                # Uses the total head count, so this assumes loaded_weight is
                # the full (unsharded) checkpoint tensor; the weight_loader
                # below is expected to do any tensor-parallel slicing.
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    # View the output dim as (num_heads, 3, head_size), swap
                    # to (3, num_heads, head_size), then flatten back.
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)