"vscode:/vscode.git/clone" did not exist on "66914f7b19ff55ea29114aa229c6b94ffc9e6a35"
opt.py 13 KB
Newer Older
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
1
"""1D tensor-parallel OPT model compatible with HuggingFace weights."""
Zhuohan Li's avatar
Zhuohan Li committed
2
3
4
5
import os
import glob
import filelock
from tqdm import tqdm
Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
from typing import Dict, List, Optional, Tuple

Zhuohan Li's avatar
Zhuohan Li committed
8
import numpy as np
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
9
10
11
import torch
from torch import nn
from transformers import OPTConfig
Zhuohan Li's avatar
Zhuohan Li committed
12
from huggingface_hub import snapshot_download
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
13

Woosuk Kwon's avatar
Woosuk Kwon committed
14
15
16
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
Zhuohan Li's avatar
Zhuohan Li committed
17
18
19
20
21
from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
                                                      ColumnParallelLinear,
                                                      RowParallelLinear)
22
from cacheflow.sequence import SequenceOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
23
24
25

# Per-layer cache pair; the attention layer unpacks it as
# (key_cache, value_cache).
KVCache = Tuple[torch.Tensor, torch.Tensor]

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embedding using OPT's fixed index shift.

    The table is allocated with ``num_embeddings + 2`` rows, and every
    position id is shifted by 2 before the lookup. This offset is an
    OPT-specific convention that other models do not use.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Set before nn.Embedding.__init__ because the row count depends on it.
        self.offset = 2
        total_rows = num_embeddings + self.offset
        super().__init__(total_rows, embedding_dim)

    def forward(self, positions: torch.LongTensor):
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """Tensor-parallel multi-head self-attention for OPT.

    Q, K and V come from a single fused column-parallel projection, so each
    tensor-parallel rank owns ``num_heads // world_size`` heads. The attention
    computation is delegated to ``OPTCacheFlowAttention``, which also receives
    this layer's key/value caches. The result goes through a row-parallel
    output projection.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        total_num_heads = num_heads
        # Validate explicitly (an `assert` is stripped under `python -O`).
        # A non-divisible embed_dim would otherwise silently truncate
        # head_dim through floor division.
        if embed_dim % total_num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads "
                f"({total_num_heads}).")
        if total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_heads ({total_num_heads}) must be divisible by the "
                f"tensor model parallel world size "
                f"({tensor_model_parallel_world_size}).")
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim ** -0.5

        # Fused QKV projection. Each rank keeps its own output shard, since
        # gather_output=False.
        self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
                                             gather_output=False,
                                             perform_initialization=False)
        # Consumes the rank-local (sharded) attention output.
        self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                          input_is_parallel=True,
                                          perform_initialization=False)
        self.attn = OPTCacheFlowAttention(scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run self-attention for the current batch.

        Args:
            hidden_states: Activations fed to the fused QKV projection.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Batch metadata forwarded to the attention op.
            cache_event: Optional CUDA event forwarded to the attention op.

        Returns:
            The output of the row-parallel output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The rank-local shard is laid out as [q | k | v] along the last dim,
        # matching how load_weights concatenates the q/k/v shards.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(
            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        output, _ = self.out_proj(attn_output)
        return output

Woosuk Kwon's avatar
Woosuk Kwon committed
79

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
80
81
82
83
class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention followed by a ReLU MLP.

    Each of the two sub-blocks is wrapped in a residual connection. When
    ``config.do_layer_norm_before`` is true, the layer norm is applied to the
    sub-block input (pre-LN). Otherwise it is applied to the residual sum
    (post-LN); per the original notes, OPT-350m uses post-LN.
    """

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        use_bias = config.enable_bias
        affine = config.layer_norm_elementwise_affine

        self.embed_dim = hidden
        self.self_attn = OPTAttention(
            embed_dim=hidden,
            num_heads=config.num_attention_heads,
            bias=use_bias,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        # Only the ReLU activation is supported here.
        assert config.activation_function == 'relu'
        self.activation_fn = nn.ReLU()

        self.self_attn_layer_norm = nn.LayerNorm(hidden, elementwise_affine=affine)
        # MLP: column-parallel up-projection, row-parallel down-projection.
        self.fc1 = ColumnParallelLinear(hidden, config.ffn_dim,
                                        bias=use_bias,
                                        gather_output=False,
                                        perform_initialization=False)
        self.fc2 = RowParallelLinear(config.ffn_dim, hidden,
                                     bias=use_bias,
                                     input_is_parallel=True,
                                     perform_initialization=False)
        self.final_layer_norm = nn.LayerNorm(hidden, elementwise_affine=affine)

    def _pre_norm(self, norm: nn.LayerNorm, x: torch.Tensor) -> torch.Tensor:
        """Apply ``norm`` to a sub-block input in pre-LN mode only."""
        return norm(x) if self.do_layer_norm_before else x

    def _post_norm(self, norm: nn.LayerNorm, x: torch.Tensor) -> torch.Tensor:
        """Apply ``norm`` to a residual sum in post-LN mode only."""
        return x if self.do_layer_norm_before else norm(x)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self-attention sub-block.
        shortcut = hidden_states
        attn_in = self._pre_norm(self.self_attn_layer_norm, hidden_states)
        attn_out = self.self_attn(
            hidden_states=attn_in,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event)
        hidden_states = self._post_norm(self.self_attn_layer_norm,
                                        shortcut + attn_out)

        # Feed-forward sub-block.
        shortcut = hidden_states
        mlp_in = self._pre_norm(self.final_layer_norm, hidden_states)
        projected, _ = self.fc1(mlp_in)
        activated = self.activation_fn(projected)
        mlp_out, _ = self.fc2(activated)
        return self._post_norm(self.final_layer_norm, shortcut + mlp_out)


Zhuohan Li's avatar
Zhuohan Li committed
145
class OPTDecoder(nn.Module):
    """Token/position embeddings, the decoder-layer stack, and the optional
    final layer norm and output projection of OPT."""

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.word_embed_proj_dim,
                                                   perform_initialization=False)
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # If the word-embedding width differs from hidden_size, bridge the two
        # with replicated (non-parallel) linear projections.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(
                config.hidden_size, config.word_embed_proj_dim, bias=False)
            self.project_in = nn.Linear(
                config.word_embed_proj_dim, config.hidden_size, bias=False)
        else:
            self.project_out = None
            self.project_in = None

        # `config._remove_final_layer_norm` exists only for backward
        # compatibility with checkpoints fine-tuned before transformers
        # v4.20.1; see https://github.com/facebookresearch/metaseq/pull/164
        keep_final_norm = (config.do_layer_norm_before
                           and not config._remove_final_layer_norm)
        self.final_layer_norm = (
            nn.LayerNorm(config.hidden_size,
                         elementwise_affine=config.layer_norm_elementwise_affine)
            if keep_final_norm else None)

        self.layers = nn.ModuleList(
            OPTDecoderLayer(config) for _ in range(config.num_hidden_layers))

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        token_embeds = self.embed_tokens(input_ids)
        position_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            token_embeds = self.project_in(token_embeds)
        hidden_states = token_embeds + position_embeds

        # Layer i uses kv_caches[i] and, if provided, cache_events[i].
        for idx, layer in enumerate(self.layers):
            event = None if cache_events is None else cache_events[idx]
            hidden_states = layer(
                hidden_states, kv_caches[idx], input_metadata, event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
214
class OPTModel(nn.Module):
    """Wrapper that exposes the decoder under the ``decoder`` attribute.

    Its parameters therefore appear as ``decoder.*`` (``model.decoder.*``
    inside OPTForCausalLM), which is the naming the converted weight files
    use.
    """

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.decoder = OPTDecoder(config)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        # Pure delegation to the decoder.
        return self.decoder(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
230
231


Zhuohan Li's avatar
Zhuohan Li committed
232
class OPTForCausalLM(nn.Module):
    """OPT decoder plus a sampler that turns final hidden states into tokens.

    The LM head weight is the token-embedding weight itself (tied). Weights
    are loaded from a directory that :meth:`get_weights` fills with one
    numpy-format file per parameter, named after its state-dict key.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = OPTModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        #                parallelism
        # Tied to the token embedding, so load_weights skips this name.
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        """Run the decoder and sample the next token for each sequence."""
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head_weight, hidden_states, input_metadata)
        return next_tokens

    # Name substrings whose tensors are sharded along dim 0 across ranks.
    _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
    # Name substrings whose tensors are sharded along dim 1 across ranks.
    _row_parallel_weights = ["out_proj.weight", "fc2.weight"]

    def load_weights(self, weights_path: str):
        """Copy this rank's shard of every parameter from ``weights_path``.

        Args:
            weights_path: Directory produced by :meth:`get_weights`, containing
                one numpy file per parameter, named after its state-dict key.

        Raises:
            ValueError: If a loaded (and sharded) tensor's shape differs from
                the parameter's shape.
        """
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, param in state_dict.items():
            # Tied to embed_tokens.weight, which is loaded under its own name.
            if "lm_head_weight" in name:
                continue

            if "qkv_proj" in name:
                # Checkpoints store q/k/v separately. Take this rank's rows of
                # each and stack them as [q | k | v].
                shard_size = param.shape[0] // 3
                start = shard_size * tensor_model_parallel_rank
                end = start + shard_size
                weights_to_concat = []
                for weight_name in ["q_proj", "k_proj", "v_proj"]:
                    weight = np.load(os.path.join(
                        weights_path, name.replace("qkv_proj", weight_name)))
                    weights_to_concat.append(weight[start:end])
                loaded_weight = torch.from_numpy(
                    np.concatenate(weights_to_concat, axis=0))
            else:
                loaded_weight = torch.from_numpy(
                    np.load(os.path.join(weights_path, name)))
                for p in self._column_parallel_weights:
                    if p in name:
                        shard_size = param.shape[0]
                        loaded_weight = loaded_weight[
                            shard_size * tensor_model_parallel_rank
                            :shard_size * (tensor_model_parallel_rank + 1)]
                        break
                for p in self._row_parallel_weights:
                    if p in name:
                        shard_size = param.shape[1]
                        loaded_weight = loaded_weight[
                            :,
                            shard_size * tensor_model_parallel_rank
                            :shard_size * (tensor_model_parallel_rank + 1)]
                        break

            # An explicit check, not an `assert`: under `python -O` a mismatch
            # could otherwise be silently broadcast by `copy_`.
            if param.shape != loaded_weight.shape:
                raise ValueError(
                    f"Shape mismatch for {name}: expected "
                    f"{tuple(param.shape)}, got {tuple(loaded_weight.shape)} "
                    f"from {weights_path}.")
            param.data.copy_(loaded_weight)

    @staticmethod
    def get_weights(model_name: str, path: str):
        """Download ``model_name`` and convert it to per-parameter numpy files.

        The output directory is ``<path>/<model_name>-np``. Conversion runs
        under a file lock. It is skipped when the
        ``model.decoder.embed_positions.weight`` file already exists. That file
        is written last (atomically), so it only appears after every other
        parameter file has been written.

        Args:
            model_name: Hugging Face Hub repository id.
            path: Base directory for the converted weights and download cache.

        Returns:
            Absolute path of the converted-weights directory.

        Raises:
            FileNotFoundError: If the downloaded snapshot contains no ``*.bin``
                checkpoint files.
        """
        path = os.path.join(path, f"{model_name}-np")
        path = os.path.abspath(os.path.expanduser(path))
        os.makedirs(path, exist_ok=True)
        lock_path = os.path.join(path, "file_lock")
        lock = filelock.FileLock(lock_path)

        # This parameter's file doubles as the "conversion finished" marker.
        sentinel_name = "model.decoder.embed_positions.weight"

        with lock:
            test_weight_path = os.path.join(path, sentinel_name)
            if os.path.exists(test_weight_path):
                return path

            folder = snapshot_download(model_name, allow_patterns="*.bin",
                                       cache_dir=os.path.join(path, "cache"))
            bin_files = glob.glob(os.path.join(folder, "*.bin"))
            if not bin_files:
                raise FileNotFoundError(
                    f"No *.bin checkpoint files found for {model_name} "
                    f"in {folder}.")

            sentinel_array = None
            for bin_file in tqdm(bin_files, desc="Convert format"):
                # NOTE(review): torch.load unpickles the file, which can run
                # arbitrary code. Only use checkpoints from trusted repos.
                state = torch.load(bin_file, map_location="cpu")
                for name, param in tqdm(state.items(), leave=False):
                    if name.startswith("decoder."):
                        name = "model." + name
                    array = param.cpu().detach().numpy()
                    if name == sentinel_name:
                        # Deferred so that an interrupted conversion is redone
                        # on the next call rather than mistaken for complete.
                        sentinel_array = array
                        continue
                    param_path = os.path.join(path, name)
                    with open(param_path, "wb") as f:
                        np.save(f, array)

            if sentinel_array is not None:
                # Write to a temp file, then rename, so a partially written
                # sentinel can never be observed.
                tmp_path = test_weight_path + ".tmp"
                with open(tmp_path, "wb") as f:
                    np.save(f, sentinel_array)
                os.replace(tmp_path, test_weight_path)

            return path