# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The CacheFlow team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
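For example, two prompts of lengths 3 and 5 arrive as a single tensor of
8 token ids; InputMetadata carries the per-sequence lengths needed to
recover the boundaries.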
"""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import OPTConfig

from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
from cacheflow.model_executor.layers.sampler import Sampler
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
                                                   load_tensor_parallel_weights)
from cacheflow.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from cacheflow.sequence import SequenceOutputs

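# A per-layer KV cache: a (key_cache, value_cache) pair of tensors laid out
# in fixed-size blocks managed by the cache engine.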
KVCache = Tuple[torch.Tensor, torch.Tensor]


class OPTLearnedPositionalEmbedding(nn.Embedding):

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)
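        # Illustrative note: with max_position_embeddings=2048, the table has
        # 2050 rows, and forward() adds the offset so position 0 reads row 2.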

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        return super().forward(positions + self.offset)


class OPTAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
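        # Shard the attention heads across tensor-parallel ranks; e.g. 12
        # heads on 4 ranks leaves 3 heads per rank.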
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim ** -0.5

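        # Fused QKV projection: one column-parallel GEMM produces the query,
        # key, and value tensors, each sharded along the head dimension.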
        self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
                                             gather_output=False,
                                             perform_initialization=False)
        self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                          input_is_parallel=True,
                                          perform_initialization=False)
        self.attn = GPTCacheFlowAttention(self.num_heads, self.head_dim,
                                          scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
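        # The attention op reads/writes the paged KV cache and handles both
        # prompt and generation tokens as described by input_metadata.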
        key_cache, value_cache = kv_cache
        attn_output = self.attn(
            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class OPTDecoderLayer(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        assert config.activation_function == 'relu'
        self.activation_fn = nn.ReLU()

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim,
                                        bias=config.enable_bias,
                                        gather_output=False,
                                        perform_initialization=False)
        self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim,
                                     bias=config.enable_bias,
                                     input_is_parallel=True,
                                     perform_initialization=False)
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        # 125m, 1.3B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.3B, ..., 175B applies layer norm BEFORE the feed-forward block
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER the feed-forward block
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class OPTDecoder(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.word_embed_proj_dim,
                                                   perform_initialization=False)
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # The input/output projections, if present, are replicated (not
        # sharded).
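        # word_embed_proj_dim differs from hidden_size only for variants like
        # OPT-350m that factorize the token embedding.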
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size,
                                         config.word_embed_proj_dim,
                                         bias=False)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim,
                                        config.hidden_size, bias=False)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is
        # to keep backward compatibility with checkpoints that have been
        # fine-tuned before transformers v4.20.1; see
        # https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
            )
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList(
            [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

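        # One cache op event per layer: the attention op waits on it so any
        # in-flight cache swap finishes before the layer's KV cache is used.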
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states, kv_caches[i], input_metadata, cache_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        return hidden_states


class OPTModel(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.decoder = OPTDecoder(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        return self.decoder(
            input_ids, positions, kv_caches, input_metadata, cache_events)


class OPTForCausalLM(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.config = config
        self.model = OPTModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        #                parallelism
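        # OPT ties the LM head to the input token embedding, so the
        # (vocab-parallel) embedding weight is reused for the final projection.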
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head_weight, hidden_states, input_metadata)
        return next_tokens

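    # Name patterns for tensor-parallel weight loading: column-parallel
    # parameters are sharded along the output dim (dim 0), row-parallel
    # parameters along the input dim (dim 1).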
    _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
    _row_parallel_weights = ["out_proj.weight", "fc2.weight"]

    def load_weights(self, model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, use_np_cache):
            if "lm_head.weight" in name:
                continue
            if name.startswith("decoder."):
                name = "model." + name

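            # HF checkpoints store q_proj/k_proj/v_proj separately, while this
            # model fuses them into a single qkv_proj. Copy each loaded tensor
            # into its third of the fused parameter, taking only this rank's
            # slice along the sharded head dimension.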
            is_attention_weight = False
            for stride_id, att_weight_name in enumerate(
                    ["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                shard_size = param.shape[0] // 3
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank
                    :shard_size * (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id
                                         :shard_size * (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)
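

# Example usage (an illustrative sketch; assumes cacheflow's tensor-parallel
# state has already been initialized, e.g. for a single GPU):
#
#   from transformers import OPTConfig
#   config = OPTConfig.from_pretrained("facebook/opt-125m")
#   model = OPTForCausalLM(config)
#   model.load_weights("facebook/opt-125m")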