# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import GPTNeoXConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class GPTNeoXAttention(nn.Module):
    """Self-attention sub-layer of a GPT-NeoX block.

    The forward pass runs a fused QKV projection, applies rotary position
    embeddings to the query and key, calls the attention backend with the
    KV cache, and finishes with the output (`dense`) projection.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: HuggingFace GPT-NeoX configuration.
            cache_config: Forwarded to `Attention`.
            quant_config: Forwarded to the linear layers and to `Attention`.
            prefix: Module-name prefix; the attention backend is created
                under `f"{prefix}.attn"`.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        # Configs that do not define `attention_bias` get biased projections.
        self.bias = getattr(config, "attention_bias", True)

        # Heads are divided evenly across tensor-parallel ranks.
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            quant_config=quant_config,
        )
        scaling = self.head_size**-0.5
        # Only a `rotary_pct` fraction of each head is rotated; the rotated
        # width must be even.
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over `hidden_states`.

        Args:
            position_ids: Positions passed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: KV cache passed to the attention backend.
            attn_metadata: Metadata passed to the attention backend.

        Returns:
            The output of the `dense` projection.
        """
        qkv, _ = self.query_key_value(hidden_states)
        # The fused projection output is split into three equal chunks
        # along the last dimension: query, key, value.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):
    """Feed-forward sub-layer of a GPT-NeoX block.

    Applies `dense_h_to_4h`, the activation named by `config.hidden_act`,
    then `dense_4h_to_h`.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the two projections and the activation.

        Args:
            config: HuggingFace GPT-NeoX configuration; supplies
                `hidden_size`, `intermediate_size` and `hidden_act`.
            quant_config: Forwarded to both linear layers.
        """
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
        )
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `dense_4h_to_h(act(dense_h_to_4h(hidden_states)))`."""
        # Each linear layer returns a tuple; only its first element is used.
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX transformer block: two LayerNorms, attention and MLP.

    `config.use_parallel_residual` selects between the parallel residual
    layout (attention and MLP both read the block input) and the sequential
    layout (the MLP reads the attention residual output).
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the norms, attention and MLP sub-layers.

        Args:
            config: HuggingFace GPT-NeoX configuration.
            cache_config: Forwarded to `GPTNeoXAttention`.
            quant_config: Forwarded to `GPTNeoXAttention` and `GPTNeoXMLP`.
            prefix: Module-name prefix; attention is created under
                `f"{prefix}.attention"`.
        """
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config,
                                          cache_config,
                                          quant_config,
                                          prefix=f"{prefix}.attention")
        self.mlp = GPTNeoXMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the block to `hidden_states`.

        Args:
            position_ids: Forwarded to the attention sub-layer.
            hidden_states: Block input; also used as the residual.
            kv_cache: Forwarded to the attention sub-layer.
            attn_metadata: Forwarded to the attention sub-layer.

        Returns:
            The block output after the residual connection(s).
        """
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


196
@support_torch_compile
class GPTNeoXModel(nn.Module):
    """GPT-NeoX backbone: token embedding, layer stack and final LayerNorm.

    The layer stack is built with `make_layers`, which also yields the
    `[start_layer, end_layer)` range this instance iterates over in
    `forward`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone from a vLLM configuration.

        Args:
            vllm_config: Supplies the HF model config, cache config and
                quantization config.
            prefix: Module-name prefix; layers are created under
                `f"{prefix}.layers"`.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTNeoXLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the `embed_in` embeddings for `input_ids`."""
        return self.embed_in(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the layer stack.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                `inputs_embeds` is not given.
            position_ids: Forwarded to every layer.
            kv_caches: Per-layer KV caches, indexed relative to
                `self.start_layer`.
            attn_metadata: Forwarded to every layer.
            intermediate_tensors: Read for `"hidden_states"` on ranks other
                than the first pipeline rank.
            inputs_embeds: Optional precomputed embeddings that take
                precedence over `input_ids` on the first pipeline rank.

        Returns:
            On the last pipeline rank, the hidden states after the final
            LayerNorm; otherwise an `IntermediateTensors` holding
            `"hidden_states"`.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                # kv_caches is indexed from 0 for the first local layer.
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


257
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
    """GPT-NeoX backbone with an LM head, logits processor and sampler."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model from a vLLM configuration.

        Args:
            vllm_config: Supplies the HF model config and quantization
                config, and is forwarded to `GPTNeoXModel`.
            prefix: Module-name prefix combined with `"gpt_neox"` through
                `maybe_prefix` for the backbone.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "gpt_neox"))
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # Input and output embeddings share a single Parameter.
            self.embed_out.weight = self.gpt_neox.embed_in.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.gpt_neox.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's input embeddings for `input_ids`."""
        return self.gpt_neox.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its result unchanged.

        All arguments are forwarded positionally to `GPTNeoXModel`.
        """
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      attn_metadata, intermediate_tensors,
                                      inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from `hidden_states` using the `embed_out` head."""
        logits = self.logits_processor(self.embed_out, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on `logits` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: Iterable of `(parameter name, tensor)` pairs.

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a checkpoint name that is not skipped has no
                matching entry in `named_parameters()`.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using OpenRLHF may include
                # these tensors in the checkpoint. Skip them.
                continue
            if (self.config.tie_word_embeddings
                    and "embed_out.weight" in name):
                # With tied embeddings, embed_out.weight is the same
                # Parameter as gpt_neox.embed_in.weight, so
                # named_parameters() lists it only under the embed_in name.
                # Skip the duplicate checkpoint entry instead of failing
                # the params_dict lookup below.
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    # Split output_dim into (num_heads, 3, rest), swap the
                    # first two of those axes, then restore the shape.
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            # Prefer a parameter-specific loader when the parameter has one.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params