# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
25

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
26
27
28
29
import torch
from torch import nn
from transformers import OPTConfig

Woosuk Kwon's avatar
Woosuk Kwon committed
30
31
32
33
34
35
36
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
Zhuohan Li's avatar
Zhuohan Li committed
37
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
38
from vllm.model_executor.parallel_utils.tensor_parallel import (
39
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
40
from vllm.sequence import SamplerOutput

# One decoder layer's KV cache: the (key_cache, value_cache) tensor pair that
# OPTAttention.forward unpacks and hands to PagedAttention.
KVCache = Tuple[torch.Tensor, torch.Tensor]

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embeddings with OPT's fixed index offset.

    OPT shifts every position id by 2 before the table lookup and enlarges
    the table by the same amount, so position ``p`` reads row ``p + 2`` of a
    ``(num_embeddings + 2, embedding_dim)`` table. This is an OPT-specific
    hack tied to how its padding index was set up; other models don't do it.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        position_offset = 2
        self.offset = position_offset
        table_rows = num_embeddings + position_offset
        super().__init__(table_rows, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Embed ``positions`` after shifting them by ``self.offset``."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """OPT multi-head self-attention, sharded across tensor-parallel ranks.

    The fused QKV projection is column-parallel and the output projection is
    row-parallel, so each rank runs PagedAttention over
    ``num_heads // world_size`` heads of its own.

    Args:
        embed_dim: Model hidden size.
        num_heads: Total number of attention heads across all ranks.
        bias: Whether the Q/K/V and output projections have bias terms.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
        AssertionError: If ``num_heads`` is not divisible by the tensor
            model parallel world size.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        # Without this check head_dim below would silently truncate and the
        # fused QKV output could not be split evenly into heads.
        if embed_dim % total_num_heads != 0:
            raise ValueError(
                "embed_dim must be divisible by num_heads (got embed_dim: "
                f"{embed_dim} and num_heads: {total_num_heads}).")
        assert num_heads % tensor_model_parallel_world_size == 0
        # Heads are split evenly across tensor-parallel ranks.
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        # Fused Q/K/V projection; gather_output=False keeps the output
        # sharded per rank (this rank's Q, K and V columns only).
        self.qkv_proj = ColumnParallelLinear(embed_dim,
                                             3 * embed_dim,
                                             bias=bias,
                                             gather_output=False,
                                             perform_initialization=False)
        # Consumes the per-rank attention output directly
        # (input_is_parallel=True).
        self.out_proj = RowParallelLinear(embed_dim,
                                          embed_dim,
                                          bias=bias,
                                          input_is_parallel=True,
                                          perform_initialization=False)
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states`` using the paged KV cache.

        Args:
            hidden_states: Token hidden states; presumably
                (num_tokens, embed_dim) given the flattened token layout
                described in the module docstring (TODO confirm).
            kv_cache: This layer's (key_cache, value_cache) pair.
            input_metadata: Batch metadata, forwarded to PagedAttention.
            cache_event: Optional CUDA event, forwarded to PagedAttention.

        Returns:
            The attention output after the row-parallel output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The local shard is laid out as [q | k | v] (load_weights writes each
        # projection into its own third of qkv_proj), so a 3-way split works.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata, cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention, then a feed-forward network.

    Each sub-block is wrapped in a residual connection.
    ``config.do_layer_norm_before`` selects where the sub-block's LayerNorm
    goes: when True it is applied to the sub-block input (pre-LN); when False
    it is applied after the residual add (post-LN).
    """

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.activation_fn = get_act_fn(config.activation_function)

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        # Tensor-parallel MLP: fc1 keeps its output sharded per rank
        # (gather_output=False) and fc2 consumes that shard directly
        # (input_is_parallel=True).
        self.fc1 = ColumnParallelLinear(self.embed_dim,
                                        config.ffn_dim,
                                        bias=config.enable_bias,
                                        gather_output=False,
                                        perform_initialization=False)
        self.fc2 = RowParallelLinear(config.ffn_dim,
                                     self.embed_dim,
                                     bias=config.enable_bias,
                                     input_is_parallel=True,
                                     perform_initialization=False)
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention and feed-forward sub-blocks with residuals.

        Args:
            hidden_states: Input hidden states for this layer.
            kv_cache: This layer's (key_cache, value_cache) pair.
            input_metadata: Batch metadata, forwarded to attention.
            cache_event: Optional CUDA event, forwarded to attention.

        Returns:
            The layer's output hidden states.
        """
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B apply layer norm BEFORE attention (pre-LN).
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       input_metadata=input_metadata,
                                       cache_event=cache_event)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention (post-LN).
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # Pre-LN models apply layer norm BEFORE the feed-forward network.
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # Post-LN models apply layer norm AFTER the feed-forward residual add.
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class OPTDecoder(nn.Module):
    """OPT decoder stack with embeddings and optional output post-processing.

    Token embeddings have width ``config.word_embed_proj_dim``. When that
    differs from ``config.hidden_size``, ``project_in`` maps token embeddings
    up to the hidden size, and ``project_out`` maps the final hidden states
    back down. Otherwise both projections are ``None``.
    """

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
            perform_initialization=False)
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist. project_out is
        # built before project_in so module registration order is unchanged.
        word_dim = config.word_embed_proj_dim
        model_dim = config.hidden_size
        if word_dim == model_dim:
            self.project_out = None
            self.project_in = None
        else:
            self.project_out = nn.Linear(model_dim, word_dim, bias=False)
            self.project_in = nn.Linear(word_dim, model_dim, bias=False)

        # Note that the only purpose of `config._remove_final_layer_norm` is
        # to keep backward compatibility with checkpoints that have been
        # fine-tuned before transformers v4.20.1; see
        # https://github.com/facebookresearch/metaseq/pull/164
        wants_final_norm = (config.do_layer_norm_before
                            and not config._remove_final_layer_norm)
        if wants_final_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList(
            OPTDecoderLayer(config) for _ in range(config.num_hidden_layers))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed the tokens, run every decoder layer, and post-process.

        Args:
            input_ids: Token ids to embed.
            positions: Position ids, one per token.
            kv_caches: One (key_cache, value_cache) pair per layer, indexed
                by layer position.
            input_metadata: Batch metadata, forwarded to every layer.
            cache_events: Either None or one event per layer.

        Returns:
            Hidden states after the optional final LayerNorm and optional
            output projection.
        """
        token_embeds = self.embed_tokens(input_ids)
        position_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            token_embeds = self.project_in(token_embeds)
        hidden_states = token_embeds + position_embeds

        for layer_idx, decoder_layer in enumerate(self.layers):
            layer_event = (None if cache_events is None else
                           cache_events[layer_idx])
            hidden_states = decoder_layer(hidden_states, kv_caches[layer_idx],
                                          input_metadata, layer_event)

        # Optional tail modules, applied in this order when present.
        for tail_module in (self.final_layer_norm, self.project_out):
            if tail_module is not None:
                hidden_states = tail_module(hidden_states)
        return hidden_states


class OPTModel(nn.Module):
    """Container holding the OPT decoder under the attribute ``decoder``.

    The attribute name matters for weight loading:
    OPTForCausalLM.load_weights rewrites checkpoint names that start with
    ``decoder.`` to ``model.decoder.*``.
    """

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.decoder = OPTDecoder(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Delegate to the decoder and return its output hidden states."""
        decoder_stack = self.decoder
        hidden_states = decoder_stack(input_ids, positions, kv_caches,
                                      input_metadata, cache_events)
        return hidden_states


class OPTForCausalLM(nn.Module):
    """OPT model with a tied LM head and vLLM's sampler on top.

    The LM head reuses the token embedding matrix (``lm_head_weight``), which
    is why checkpoint ``lm_head.weight`` entries are skipped when loading.
    """

    # Weight names passed to load_tensor_parallel_weights as column-parallel
    # and row-parallel, respectively.
    _column_parallel_weights = [
        "embed_tokens.weight", "fc1.weight", "fc1.bias"
    ]
    _row_parallel_weights = ["out_proj.weight", "fc2.weight"]

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = OPTModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        #                parallelism
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the model and sample next tokens with the tied LM head."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        return self.sampler(self.lm_head_weight, hidden_states,
                            input_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace OPT weights into this (possibly sharded) model.

        Each separate ``q_proj`` / ``k_proj`` / ``v_proj`` checkpoint tensor
        is sliced to this rank's rows and copied into its third of the fused
        ``qkv_proj`` parameter. Every other tensor goes through
        ``load_tensor_parallel_weights``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        params = self.state_dict()
        # Order matters: the index is the slot within the fused QKV shard.
        qkv_parts = ("q_proj", "k_proj", "v_proj")

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "lm_head.weight" in name:
                # Tied to the token embedding; nothing separate to load.
                continue
            if name.startswith("decoder."):
                name = "model." + name

            # First Q/K/V projection name contained in `name`, if any.
            qkv_match = next(((slot, part)
                              for slot, part in enumerate(qkv_parts)
                              if part in name), None)
            if qkv_match is None:
                load_tensor_parallel_weights(params[name], loaded_weight,
                                             name,
                                             self._column_parallel_weights,
                                             self._row_parallel_weights,
                                             tp_rank)
                continue

            slot, part = qkv_match
            fused_param = params[name.replace(part, "qkv_proj")]
            rows = fused_param.shape[0] // 3
            rank_rows = loaded_weight[rows * tp_rank:rows * (tp_rank + 1)]
            dest = fused_param.data[rows * slot:rows * (slot + 1)]
            assert dest.shape == rank_rows.shape
            dest.copy_(rank_rows)