# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
18
from typing import Iterable, List, Optional, Tuple, Union
19
20
21

import torch
from torch import nn
22
23
from transformers import GPTNeoXConfig

24
from vllm.attention import Attention, AttentionMetadata
25
from vllm.compilation.decorators import support_torch_compile
26
from vllm.config import CacheConfig, VllmConfig
27
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
28
from vllm.model_executor.layers.activation import get_act_fn
29
30
31
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
32
from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
from vllm.model_executor.layers.quantization import QuantizationConfig
34
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
35
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
36
from vllm.model_executor.layers.vocab_parallel_embedding import (
37
    ParallelLMHead, VocabParallelEmbedding)
38
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
39
from vllm.model_executor.sampling_metadata import SamplingMetadata
40
from vllm.sequence import IntermediateTensors
41

42
43
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
44
45
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
46

47
48
49

class GPTNeoXAttention(nn.Module):
    """GPT-NeoX self-attention with a fused QKV projection and partial RoPE.

    Heads are split evenly across tensor-parallel ranks: each rank handles
    ``num_heads = total_num_heads // tp_size`` heads, and the ``dense``
    projection maps the attention output back to ``hidden_size``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: HuggingFace GPT-NeoX config. Reads ``num_attention_heads``,
                ``hidden_size`` and ``rotary_pct``, and optionally
                ``attention_bias``, ``rope_theta`` / ``rotary_emb_base`` and
                ``max_position_embeddings``.
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Quantization configuration forwarded to the linear
                projections and to ``Attention``.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        # Used for both the QKV and the output projection; biased by default.
        self.bias = getattr(config, "attention_bias", True)

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be divisible "
            f"by the tensor parallel size ({tensor_model_parallel_world_size})")
        # Number of heads owned by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            quant_config=quant_config,
        )
        scaling = self.head_size**-0.5

        # RoPE covers ``rotary_pct`` of each head's dimensions.
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0, (
            f"rotary_dim ({rotary_dim}) must be even; check rotary_pct "
            f"({config.rotary_pct}) against head size ({self.head_size})")
        # The transformers GPTNeoXConfig this file is adapted from (v4.28)
        # stores the RoPE base as ``rotary_emb_base``; only some versions
        # also expose ``rope_theta``. Fall back to ``rotary_emb_base`` so a
        # non-default base is not silently replaced by 10000.
        rope_theta = getattr(config, "rope_theta",
                             getattr(config, "rotary_emb_base", 10000))
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute self-attention for one layer.

        Args:
            position_ids: Token positions consumed by the rotary embedding.
            hidden_states: Input activations for this layer.
            kv_cache: This layer's KV cache, passed through to ``Attention``.
            attn_metadata: Batch attention metadata, passed through to
                ``Attention``.

        Returns:
            The attention output after the ``dense`` projection.
        """
        qkv, _ = self.query_key_value(hidden_states)
        # Q, K and V use the same head count and head size here (no separate
        # KV-head count is passed to QKVParallelLinear), so an even three-way
        # split of the last dimension separates them.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):
    """GPT-NeoX feed-forward block.

    Computes ``dense_4h_to_h(act(dense_h_to_4h(x)))``, going from
    ``hidden_size`` up to ``intermediate_size`` and back.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the up/down projections and the activation.

        Args:
            config: HuggingFace GPT-NeoX config; reads ``hidden_size``,
                ``intermediate_size`` and ``hidden_act``.
            quant_config: Quantization configuration for both projections.
        """
        super().__init__()
        width = config.hidden_size
        inner_width = config.intermediate_size
        # Up-projection: hidden_size -> intermediate_size.
        self.dense_h_to_4h = ColumnParallelLinear(width,
                                                  inner_width,
                                                  quant_config=quant_config)
        # Down-projection: intermediate_size -> hidden_size.
        self.dense_4h_to_h = RowParallelLinear(inner_width,
                                               width,
                                               quant_config=quant_config)
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        """Apply up-projection, activation and down-projection."""
        expanded, _ = self.dense_h_to_4h(hidden_states)
        projected, _ = self.dense_4h_to_h(self.act(expanded))
        return projected


class GPTNeoXLayer(nn.Module):
    """A single GPT-NeoX transformer block.

    Supports both residual layouts selected by ``use_parallel_residual``:

    * parallel:   ``x = x + attn(ln1(x)) + mlp(ln2(x))``
    * sequential: ``x = x + attn(ln1(x)); x = x + mlp(ln2(x))``
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the two layer norms, the attention and the MLP sub-blocks.

        Args:
            config: HuggingFace GPT-NeoX config; reads
                ``use_parallel_residual``, ``hidden_size`` and
                ``layer_norm_eps`` (and is forwarded to the sub-blocks).
            cache_config: KV-cache configuration for the attention sub-block.
            quant_config: Quantization configuration for both sub-blocks.
        """
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        norm_width = config.hidden_size
        norm_eps = config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(norm_width, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(norm_width, eps=norm_eps)
        self.attention = GPTNeoXAttention(config, cache_config, quant_config)
        self.mlp = GPTNeoXMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and MLP with the configured residual layout.

        Args:
            position_ids: Token positions for the rotary embedding.
            hidden_states: Layer input (also the residual stream).
            kv_cache: This layer's KV cache.
            attn_metadata: Batch attention metadata.

        Returns:
            The updated residual stream.
        """
        residual = hidden_states
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=self.input_layernorm(residual),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if not self.use_parallel_residual:
            # Sequential: the MLP sees the post-attention residual stream.
            attn_output = attn_output + residual
            mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
            return mlp_output + attn_output

        # Parallel: attention and MLP both read the original layer input.
        mlp_output = self.mlp(self.post_attention_layernorm(residual))
        return mlp_output + attn_output + residual


190
@support_torch_compile
class GPTNeoXModel(nn.Module):
    """GPT-NeoX decoder stack: token embedding, transformer layers, final norm.

    Under pipeline parallelism this rank holds layers
    ``[start_layer, end_layer)``; intermediate ranks exchange hidden states
    through ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, this rank's layers and the final layer norm.

        Args:
            vllm_config: Engine config; supplies the HF model config, the
                cache config and the quantization config.
            prefix: Module-name prefix; used for the ``.layers`` prefix given
                to ``make_layers``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = hf_config

        self.embed_in = VocabParallelEmbedding(hf_config.vocab_size,
                                               hf_config.hidden_size)

        def build_layer(prefix: str):
            # make_layers passes a per-layer prefix; GPTNeoXLayer does not
            # take one, so it is accepted and ignored here.
            return GPTNeoXLayer(hf_config, cache_config, quant_config)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.final_layer_norm = nn.LayerNorm(hf_config.hidden_size,
                                             eps=hf_config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder.

        Args:
            input_ids: Token ids; embedded only on the first PP rank.
            position_ids: Token positions for the rotary embeddings.
            kv_caches: KV caches for this rank's layers, indexed from 0 at
                ``start_layer``.
            attn_metadata: Batch attention metadata.
            intermediate_tensors: Hidden states from the previous PP rank;
                read only when this is not the first rank.

        Returns:
            ``IntermediateTensors`` on every rank but the last; on the last
            rank, the hidden states after ``final_layer_norm``.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            hidden_states = self.embed_in(input_ids)
        else:
            hidden_states = intermediate_tensors["hidden_states"]

        local_layers = self.layers[self.start_layer:self.end_layer]
        for cache_idx, layer in enumerate(local_layers):
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[cache_idx],
                attn_metadata,
            )

        if not pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return self.final_layer_norm(hidden_states)


243
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
    """GPT-NeoX causal language model: decoder stack plus ``embed_out`` head.

    ``embed_out`` is the output projection used by ``compute_logits``. When
    ``tie_word_embeddings`` is set it shares its weight with
    ``gpt_neox.embed_in``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder, the LM head, the logits processor and sampler.

        Args:
            vllm_config: Engine config; supplies the HF model config and the
                quantization config.
            prefix: Module-name prefix; the decoder lives under ``gpt_neox``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "gpt_neox"))
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # Reuse the input embedding matrix as the output projection.
            self.embed_out.weight = self.gpt_neox.embed_in.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.gpt_neox.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder.

        Returns the decoder output unchanged: final hidden states on the last
        PP rank, ``IntermediateTensors`` elsewhere.
        """
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through ``embed_out``."""
        logits = self.logits_processor(self.embed_out, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    @staticmethod
    def _neox_qkv_to_vllm_layout(loaded_weight: torch.Tensor,
                                 output_dim: int,
                                 num_heads: int) -> torch.Tensor:
        """Reorder a fused QKV checkpoint tensor to q|k|v-major order.

        GPT-NeoX's fused QKV output dimension is laid out as
        (num_heads * 3 * head_size), while QKVParallelLinear expects
        (3 * num_heads * head_size). The output dimension is split into
        (num_heads, 3, head_size), the first two axes are swapped, and the
        result is flattened back to the original shape.

        Args:
            loaded_weight: Full (unsharded) QKV weight or bias from the
                checkpoint.
            output_dim: Index of the fused output dimension in
                ``loaded_weight``.
            num_heads: Total number of attention heads.

        Returns:
            A tensor with the same shape as ``loaded_weight``, reordered.
        """
        original_shape = loaded_weight.shape
        split = loaded_weight.view(original_shape[:output_dim] +
                                   (num_heads, 3, -1) +
                                   original_shape[output_dim + 1:])
        swapped = split.transpose(output_dim, output_dim + 1)
        return swapped.reshape(original_shape)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HuggingFace GPT-NeoX checkpoint tensors into this model.

        Skips non-parameter buffers (causal-mask biases, rotary caches),
        parameters owned by other pipeline-parallel ranks, and, when the
        embeddings are tied, the duplicate ``embed_out.weight``. QKV tensors
        are reordered to the layout QKVParallelLinear expects before loading.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

        Raises:
            KeyError: If a checkpoint name has no matching parameter.
        """
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using OpenRLHF may include
                # these tensors in the checkpoint. Skip them.
                continue
            if self.config.tie_word_embeddings and "embed_out.weight" in name:
                # When tied, embed_out.weight is the same Parameter as
                # gpt_neox.embed_in.weight. named_parameters() drops the
                # duplicate, keeping the name registered first (gpt_neox), so
                # this name is not in params_dict. The values are loaded via
                # the input embedding instead.
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                output_dim = getattr(param, "output_dim", None)
                if output_dim is not None:
                    loaded_weight = self._neox_qkv_to_vllm_layout(
                        loaded_weight, output_dim,
                        self.config.num_attention_heads)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)