# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import GPTNeoXConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class GPTNeoXAttention(nn.Module):
    """GPT-NeoX self-attention with (partial) rotary position embeddings.

    Attention heads are split evenly across tensor-parallel ranks: each rank
    holds ``num_attention_heads // tp_world_size`` heads.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads
        # Projections carry a bias unless the config explicitly opts out.
        self.bias = getattr(config, "attention_bias", True)

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        # Number of heads owned by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # Fused Q/K/V projection. GPTNeoXForCausalLM.load_weights reorders the
        # checkpoint's per-head interleaved layout so this output can be split
        # into contiguous q | k | v chunks in forward().
        self.query_key_value = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=self.bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=self.bias,
            quant_config=quant_config,
        )
        scaling = self.head_size**-0.5
        # Only the first `rotary_pct` fraction of each head dimension is
        # rotated.
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        # HF GPTNeoXConfig has historically stored the RoPE base as
        # `rotary_emb_base`; prefer `rope_theta` when the config provides it,
        # otherwise fall back to `rotary_emb_base`, then to 10000.
        rope_theta = getattr(config, "rope_theta",
                             getattr(config, "rotary_emb_base", 10000))
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply RoPE to q and k, attend, project back.

        Args:
            position_ids: Token positions used by the rotary embedding.
            hidden_states: Layer input (already layer-normed by the caller).
            kv_cache: This layer's KV cache tensor.
            attn_metadata: Attention metadata for the current batch.

        Returns:
            The attention output after the ``dense`` output projection.
        """
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):
    """GPT-NeoX feed-forward block.

    Computes ``dense_4h_to_h(act(dense_h_to_4h(x)))``: a column-parallel
    up-projection to ``intermediate_size``, the configured activation, and a
    row-parallel down-projection back to ``hidden_size``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
        )
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        """Apply up-projection, activation and down-projection."""
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):
    """One GPT-NeoX decoder layer (attention + MLP with residuals).

    ``config.use_parallel_residual`` selects between the parallel form
    ``x + attn(ln1(x)) + mlp(ln2(x))`` and the sequential form
    ``y = x + attn(ln1(x)); y + mlp(ln2(y))``.
    """

    def __init__(
        self,
        config: GPTNeoXConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config,
                                          cache_config,
                                          quant_config,
                                          prefix=f"{prefix}.attention")
        self.mlp = GPTNeoXMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and MLP sub-blocks and return the new hidden state."""
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


@support_torch_compile
class GPTNeoXModel(nn.Module):
    """GPT-NeoX transformer trunk: embeddings, decoder layers, final norm.

    Supports pipeline parallelism. Only this stage's layers
    ``[start_layer, end_layer)`` are executed. Non-last stages return their
    hidden states wrapped in ``IntermediateTensors`` instead of applying the
    final layer norm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GPTNeoXLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_in(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's layers.

        On the first stage, input comes from ``inputs_embeds`` (if given) or
        from embedding ``input_ids``. On later stages it is taken from
        ``intermediate_tensors["hidden_states"]``, which must then be
        provided.

        Returns:
            ``IntermediateTensors`` on non-last stages; otherwise the final
            layer-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Fail with a clear message rather than a TypeError on None[...].
            assert intermediate_tensors is not None, (
                "intermediate_tensors is required on non-first PP ranks")
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                # kv_caches only covers this stage's layers, hence the offset.
                kv_caches[i - self.start_layer],
                attn_metadata,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class GPTNeoXForCausalLM(nn.Module, SupportsPP):
    """GPT-NeoX causal LM: ``GPTNeoXModel`` trunk plus ``embed_out`` head.

    When ``config.tie_word_embeddings`` is set, the LM head shares its weight
    Parameter with the trunk's input embedding.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "gpt_neox"))
        self.embed_out = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # Share a single Parameter between the embedding and LM head.
            self.embed_out.weight = self.gpt_neox.embed_in.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.gpt_neox.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the trunk's ``embed_in``."""
        return self.gpt_neox.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the trunk; logits are produced separately by compute_logits."""
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      attn_metadata, intermediate_tensors,
                                      inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``embed_out``."""
        logits = self.logits_processor(self.embed_out, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HF-format checkpoint tensors into this model.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Checkpoint entries with no matching parameter in this model.
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using OpenRLHF may include
                # these tensors in the checkpoint. Skip them.
                continue
            if (self.config.tie_word_embeddings
                    and "embed_out.weight" in name):
                # With tied embeddings, named_parameters() de-duplicates the
                # shared Parameter, so "embed_out.weight" has no entry in
                # params_dict (looking it up would raise KeyError). Its values
                # come from the embed_in weight instead.
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params