opt.py 11.8 KB
Newer Older
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
1
"""1D OPT model compatible with HuggingFace weights."""
Woosuk Kwon's avatar
Woosuk Kwon committed
2
3
from typing import Dict, List, Optional, Tuple

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
4
5
6
7
import torch
from torch import nn
from transformers import OPTConfig

Woosuk Kwon's avatar
Woosuk Kwon committed
8
9
10
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
11
12
from cacheflow.models.utils import (hf_model_weights_iterator,
                                    load_tensor_parallel_weights)
Zhuohan Li's avatar
Zhuohan Li committed
13
14
15
16
17
from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
                                                      ColumnParallelLinear,
                                                      RowParallelLinear)
18
from cacheflow.sequence import SequenceOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
19
20
21

KVCache = Tuple[torch.Tensor, torch.Tensor]

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned positional embedding whose lookups are shifted by a fixed offset.

    The table is allocated with ``offset`` extra rows, and every position id
    given to ``forward`` is increased by ``offset`` before the lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT-specific hack: embedding ids are offset by 2 and the table size
        # is enlarged accordingly. The offset is stored before the parent
        # constructor runs because the table size depends on it.
        self.offset = 2
        table_rows = num_embeddings + self.offset
        super().__init__(table_rows, embedding_dim)

    def forward(self, positions: torch.LongTensor):
        # Shift the position ids so they index past the reserved rows.
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """OPT attention: fused QKV projection, cached attention, output projection.

    The local head count stored on the module is the total number of heads
    divided by the tensor-model-parallel world size.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
    ) -> None:
        """Build the projections and the attention op.

        Args:
            embed_dim: Model hidden size.
            num_heads: Total number of attention heads; must be divisible by
                the tensor-model-parallel world size.
            bias: Whether the QKV and output projections have a bias term.
        """
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        total_num_heads = num_heads
        assert total_num_heads % tensor_model_parallel_world_size == 0
        # ``num_heads`` is the per-rank share; ``head_dim`` is derived from
        # the total so it does not depend on the parallel degree.
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim ** -0.5

        # Q, K and V come from one fused linear layer with 3 * embed_dim
        # output features.
        # NOTE(review): presumably gather_output=False leaves the output
        # sharded per rank for the attention op -- confirm against
        # ColumnParallelLinear.
        self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
                                             gather_output=False,
                                             perform_initialization=False)
        self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                          input_is_parallel=True,
                                          perform_initialization=False)
        self.attn = OPTCacheFlowAttention(scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: Input activations for this layer.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Forwarded unchanged to the attention op.
            cache_event: Optional CUDA event, forwarded unchanged to the
                attention op.

        Returns:
            The output of ``out_proj`` applied to the attention result.
        """
        # Both parallel linears return a 2-tuple; only the first element is
        # used here.
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into three equal chunks on the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(
            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        output, _ = self.out_proj(attn_output)
        return output

Woosuk Kwon's avatar
Woosuk Kwon committed
75

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
76
77
78
79
class OPTDecoderLayer(nn.Module):
    """One OPT decoder layer: self-attention and a two-layer ReLU feed-forward.

    Each sub-block is wrapped in a residual connection. Whether layer norm is
    applied before or after each sub-block is controlled by
    ``config.do_layer_norm_before``.
    """

    def __init__(self, config: OPTConfig):
        """Build the layer from ``config``.

        Reads ``hidden_size``, ``num_attention_heads``, ``enable_bias``,
        ``do_layer_norm_before``, ``activation_function``, ``ffn_dim`` and
        ``layer_norm_elementwise_affine``. Only the ``'relu'`` activation is
        supported.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        assert config.activation_function == 'relu'
        self.activation_fn = nn.ReLU()

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
        # Feed-forward: embed_dim -> ffn_dim -> embed_dim.
        self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim,
                                        bias=config.enable_bias,
                                        gather_output=False,
                                        perform_initialization=False)
        self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim,
                                     bias=config.enable_bias,
                                     input_is_parallel=True,
                                     perform_initialization=False)
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-blocks to ``hidden_states``.

        Args:
            hidden_states: Input activations for this layer.
            kv_cache: Key/value cache pair for this layer, forwarded to the
                attention module.
            input_metadata: Forwarded unchanged to the attention module.
            cache_event: Optional CUDA event, forwarded unchanged to the
                attention module.

        Returns:
            The layer output after both residual sub-blocks.
        """
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE the feed-forward block
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        # The parallel linears return a 2-tuple; only the first element is used.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER the feed-forward block
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
141
class OPTDecoder(nn.Module):
    """OPT decoder stack: embeddings, decoder layers, optional norm/projection."""

    def __init__(self, config: OPTConfig):
        """Build embeddings, optional in/out projections and the layer stack.

        ``project_in`` / ``project_out`` exist only when
        ``config.word_embed_proj_dim`` differs from ``config.hidden_size``;
        otherwise both are ``None``.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.word_embed_proj_dim,
                                                   perform_initialization=False)
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist.
        # project_out is assigned before project_in to keep the module
        # registration (and therefore state_dict key) order stable.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(
                config.hidden_size, config.word_embed_proj_dim, bias=False)
            self.project_in = nn.Linear(
                config.word_embed_proj_dim, config.hidden_size, bias=False)
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
        # with checkpoints that have been fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
            )
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList(
            [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed the inputs and run them through every decoder layer.

        Args:
            input_ids: Token ids, looked up in ``embed_tokens``.
            positions: Position ids, looked up in ``embed_positions``.
            kv_caches: Per-layer key/value caches, indexed by layer number.
            input_metadata: Forwarded unchanged to every layer.
            cache_events: Optional per-layer CUDA events, indexed by layer
                number; ``None`` passes ``None`` to every layer.

        Returns:
            Hidden states after the last layer, the optional final layer
            norm and the optional output projection.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

        for i, layer in enumerate(self.layers):
            cache_event = None if cache_events is None else cache_events[i]
            hidden_states = layer(
                hidden_states, kv_caches[i], input_metadata, cache_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
210
class OPTModel(nn.Module):
    """Wrapper that holds an ``OPTDecoder`` under the ``decoder`` attribute."""

    def __init__(self, config: OPTConfig):
        """Create the decoder from ``config``."""
        super().__init__()
        self.decoder = OPTDecoder(config)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Forward every argument to the decoder and return its output."""
        return self.decoder(
            input_ids, positions, kv_caches, input_metadata, cache_events)
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
226
227


Zhuohan Li's avatar
Zhuohan Li committed
228
class OPTForCausalLM(nn.Module):
    """OPT language model: decoder stack plus a sampler over the tied LM head.

    The LM head has no weight of its own; it reuses the token-embedding
    weight of the decoder.
    """

    def __init__(self, config):
        """Build the model and the sampler from ``config``."""
        super().__init__()
        self.config = config
        self.model = OPTModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        #                parallelism
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        """Run the model and sample from the resulting hidden states.

        Args:
            input_ids: Token ids for the decoder.
            positions: Position ids for the decoder.
            kv_caches: Per-layer key/value caches.
            input_metadata: Forwarded to the decoder and to the sampler.
            cache_events: Optional per-layer CUDA events.

        Returns:
            Whatever the sampler produces for the LM-head weight, the hidden
            states and ``input_metadata``.
        """
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head_weight, hidden_states, input_metadata)
        return next_tokens

    # Name lists handed to ``load_tensor_parallel_weights``.
    # NOTE(review): presumably these select which parameters are sharded
    # along the output vs. input dimension -- confirm against that helper.
    _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
    _row_parallel_weights = ["out_proj.weight", "fc2.weight"]

    def load_weights(self, model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        """Copy HuggingFace checkpoint weights into this model's parameters.

        Args:
            model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            use_np_cache: Forwarded to ``hf_model_weights_iterator``.

        Raises:
            KeyError: If a checkpoint tensor name (after renaming) has no
                matching entry in this model's ``state_dict``.
        """
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, use_np_cache):
            # The LM head reuses the embedding weight, so the checkpoint's
            # separate lm_head tensor is not loaded.
            if "lm_head.weight" in name:
                continue

            # Checkpoint names starting with "decoder." live under
            # ``self.model`` here.
            if name.startswith("decoder."):
                name = "model." + name

            # q_proj / k_proj / v_proj are copied into consecutive thirds of
            # the fused qkv_proj parameter, in that order.
            is_attention_weight = False
            for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                # One third of the local fused parameter along dim 0.
                shard_size = param.shape[0] // 3
                # Take this rank's rows from the checkpoint tensor.
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank
                    :shard_size * (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id
                                         :shard_size * (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights)

    def initialize_dummy_weights(self) -> None:
        """Fill every tensor in ``state_dict`` in place with uniform values in [-1e-3, 1e-3]."""
        for param in self.state_dict().values():
            param.data.uniform_(-1e-3, 1e-3)