# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTNeoXConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]
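# Each KVCache pairs one layer's key cache with its value cache; the paged
# block layout of these tensors is managed by vLLM's cache engine, and this
# module only threads them through to the attention kernel.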


class GPTNeoXAttention(nn.Module):

    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
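        # Attention heads are split evenly across tensor-parallel ranks,
        # e.g. 32 total heads on 4 ranks gives 8 heads per rank.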
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        self.query_key_value = ColumnParallelLinear(
            config.hidden_size,
            3 * config.hidden_size,
            gather_output=False,
            perform_initialization=False)
        self.dense = RowParallelLinear(config.hidden_size,
                                       config.hidden_size,
                                       input_is_parallel=True,
                                       perform_initialization=False)

        scaling = self.head_size**-0.5
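        # GPT-NeoX applies rotary embeddings to only the first `rotary_pct`
        # fraction of each head's dimensions (rotary_pct == 1.0 rotates the
        # full head).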
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_size,
            scaling,
            rotary_dim,
            base=rope_theta,
            max_position=max_position_embeddings)
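        # PagedAttentionWithRoPE applies the rotary embedding to q and k
        # inside the attention op, so this layer needs no separate rotary
        # module.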

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
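        # On this rank, q, k and v are each [num_tokens,
        # num_heads * head_size]; load_weights() re-lays-out the fused
        # checkpoint weight so this chunk yields contiguous Q, K, V blocks.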
        k_cache, v_cache = kv_cache
        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.dense(attn_output)
        return output


class GPTNeoXMLP(nn.Module):

    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
                                                  config.intermediate_size,
                                                  gather_output=False,
                                                  perform_initialization=False)
        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                               config.hidden_size,
                                               input_is_parallel=True,
                                               perform_initialization=False)
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):

    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config)
        self.mlp = GPTNeoXMLP(config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
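            # Both branches read the same input x, so attention and MLP can
            # be computed independently; GPT-NeoX-20B ships with
            # use_parallel_residual=True and uses this formulation.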
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


class GPTNeoXModel(nn.Module):

    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.config = config

        self.embed_in = VocabParallelEmbedding(config.vocab_size,
                                               config.hidden_size,
                                               perform_initialization=False)
        self.layers = nn.ModuleList(
            [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_in(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class GPTNeoXForCausalLM(nn.Module):

    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.config = config
        self.gpt_neox = GPTNeoXModel(config)
        self.embed_out = ColumnParallelLinear(config.hidden_size,
                                              config.vocab_size,
                                              bias=False,
                                              gather_output=False,
                                              perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)
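        # The sampler reuses embed_out's weight to project hidden states to
        # vocab logits and sample next tokens; embed_out is never invoked as
        # a module in forward().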

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      input_metadata, cache_events)
        next_tokens = self.sampler(self.embed_out.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
        "dense_h_to_4h.bias"
    ]
    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
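    # Column-parallel weights are sharded along the output dimension
    # (dim 0 of the weight), row-parallel weights along the input dimension
    # (dim 1); load_tensor_parallel_weights() slices the full checkpoint
    # tensor accordingly for this rank.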

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                continue
            param = state_dict[name]
            if "query_key_value" in name:
                # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
                # [num_heads * 3 * head_size, hidden_size], while the
                # required shape is [3 * num_heads * head_size, hidden_size].
                # Thus, we need weight conversion.
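                # Worked example with num_heads=2: the HF checkpoint orders
                # rows per head as [q0, k0, v0, q1, k1, v1]; the
                # view/transpose/reshape below reorders them to
                # [q0, q1, k0, k1, v0, v1].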
                shard_size = param.shape[0]
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]

                num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // num_heads
                if "query_key_value.weight" in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size,
                                                       hidden_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif "query_key_value.bias" in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1)
                else:
                    raise ValueError(f"Unexpected weight name: {name}")
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)
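

# A minimal usage sketch, for orientation only. It assumes vLLM's public
# `LLM` entry point, which dispatches to this module for GPT-NeoX-family
# checkpoints; the model name and sampling settings below are illustrative:
#
#     from vllm import LLM, SamplingParams
#
#     llm = LLM(model="EleutherAI/pythia-1b")
#     params = SamplingParams(temperature=0.8, max_tokens=32)
#     outputs = llm.generate(["The capital of France is"], params)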