# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
24
from typing import List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
25

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
26
27
28
29
import torch
from torch import nn
from transformers import OPTConfig

Woosuk Kwon's avatar
Woosuk Kwon committed
30
31
32
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
33
34
35
36
37
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
Woosuk Kwon's avatar
Woosuk Kwon committed
38
from vllm.model_executor.layers.sampler import Sampler
39
40
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
Woosuk Kwon's avatar
Woosuk Kwon committed
41
from vllm.model_executor.parallel_utils.parallel_state import (
42
43
44
    get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
45
from vllm.sequence import SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
46
47
48

# Per-layer cache passed to attention; unpacked as (key_cache, value_cache).
KVCache = Tuple[torch.Tensor, torch.Tensor]

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
49
50
51
52

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embedding with OPT's index offset.

    OPT is set up so that position ids are shifted by 2 before lookup; the
    table therefore holds ``num_embeddings + 2`` rows. Other models don't
    use this hack.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        offset = 2
        # Allocate room for the extra leading rows the offset skips over.
        super().__init__(num_embeddings + offset, embedding_dim)
        self.offset = offset

    def forward(self, positions: torch.Tensor):
        """Look up embeddings for ``positions`` after applying the offset."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """OPT multi-head self-attention split across tensor-parallel ranks.

    Q, K and V come from a single fused ``QKVParallelLinear``; each rank keeps
    ``num_heads // world_size`` heads and runs ``PagedAttention`` over them,
    followed by a ``RowParallelLinear`` output projection.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tp_size = get_tensor_model_parallel_world_size()
        # Heads are divided evenly among ranks.
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size
        # Head width is based on the global head count, not the local one.
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            num_heads,
            bias=bias,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            linear_method=linear_method,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` using (and updating) ``kv_cache``."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [q | k | v] on the last axis.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        context = self.attn(query, key, value, key_cache, value_cache,
                            input_metadata, cache_event)
        projected, _ = self.out_proj(context)
        return projected

Woosuk Kwon's avatar
Woosuk Kwon committed
114

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
115
116
class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention followed by a 2-layer MLP.

    Both sub-blocks are wrapped in a residual connection. When
    ``config.do_layer_norm_before`` is true the sub-block's LayerNorm is
    applied to its input (pre-norm); otherwise it is applied to the residual
    sum (post-norm).
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        width = self.embed_dim
        self.self_attn = OPTAttention(
            embed_dim=width,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
            linear_method=linear_method,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.activation_fn = get_act_fn(config.activation_function)

        affine = config.layer_norm_elementwise_affine
        self.self_attn_layer_norm = nn.LayerNorm(width,
                                                 elementwise_affine=affine)
        self.fc1 = ColumnParallelLinear(
            width,
            config.ffn_dim,
            bias=config.enable_bias,
            linear_method=linear_method,
        )
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            width,
            bias=config.enable_bias,
            linear_method=linear_method,
        )
        self.final_layer_norm = nn.LayerNorm(width, elementwise_affine=affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-blocks in sequence."""
        # Most OPT sizes (125m, 1.7B, ..., 175B) use pre-norm; 350m uses
        # post-norm.
        pre_norm = self.do_layer_norm_before

        # Self-attention sub-block.
        residual = hidden_states
        attn_input = (self.self_attn_layer_norm(hidden_states)
                      if pre_norm else hidden_states)
        attn_output = self.self_attn(hidden_states=attn_input,
                                     kv_cache=kv_cache,
                                     input_metadata=input_metadata,
                                     cache_event=cache_event)
        hidden_states = residual + attn_output
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward sub-block: fc1 -> activation -> fc2.
        residual = hidden_states
        mlp_input = (self.final_layer_norm(hidden_states)
                     if pre_norm else hidden_states)
        mlp_hidden, _ = self.fc1(mlp_input)
        mlp_hidden = self.activation_fn(mlp_hidden)
        mlp_output, _ = self.fc2(mlp_hidden)
        hidden_states = residual + mlp_output
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
189
class OPTDecoder(nn.Module):
    """Token/position embeddings followed by a stack of OPT decoder layers.

    When ``config.word_embed_proj_dim`` differs from ``config.hidden_size``,
    token embeddings are produced in ``word_embed_proj_dim`` and projected in
    to ``hidden_size`` before the layers, then projected back out after them.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in are replicated, and only exist when the token
        # embedding width differs from the model width. project_out is
        # registered first to keep the original parameter ordering.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                linear_method=linear_method)
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               linear_method=linear_method)
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList([
            OPTDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids``/``positions`` and run every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Position ids, looked up in ``embed_positions``.
            kv_caches: One KV cache per layer, indexed by layer number.
            input_metadata: Attention metadata forwarded to each layer.
            cache_events: Optional per-layer events; ``None`` disables them.

        Returns:
            The final hidden states (after the optional final LayerNorm and
            ``project_out``).
        """
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            # Like every other parallel linear layer in this file,
            # ReplicatedLinear returns (output, bias); keep only the output.
            inputs_embeds, _ = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

        for i, layer in enumerate(self.layers):
            cache_event = None if cache_events is None else cache_events[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
                                  cache_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
273
class OPTModel(nn.Module):
    """Container that holds the OPT decoder under the ``decoder`` attribute.

    The extra level of nesting makes parameter names take the form
    ``decoder.*``.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.decoder = OPTDecoder(config, linear_method)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Run the wrapped decoder and return its hidden states."""
        hidden_states = self.decoder(input_ids, positions, kv_caches,
                                     input_metadata, cache_events)
        return hidden_states
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
293
294


Zhuohan Li's avatar
Zhuohan Li committed
295
class OPTForCausalLM(nn.Module):
    """OPT language model: decoder stack plus a tied LM head and sampler."""

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = OPTModel(config, linear_method)
        # The LM head reuses (is tied to) the token embedding matrix.
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the model and sample next tokens with the tied LM head."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace OPT checkpoint tensors into this module.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are written into
        the matching shard of the fused ``qkv_proj`` parameter; every other
        tensor is loaded into the parameter with the same name.

        Args:
            model_name_or_path: Checkpoint name or local path.
            cache_dir: Optional download/cache directory.
            load_format: Weight file format selector passed to the iterator.
            revision: Optional checkpoint revision.

        Raises:
            KeyError: If a checkpoint tensor has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "lm_head.weight" in name:
                # Tied to embed_tokens; there is nothing separate to load.
                continue
            # Checkpoints saved from a bare OPTModel name tensors
            # "decoder.*", while this module nests the decoder under
            # `self.model`; add the prefix so the lookup below matches.
            if name.startswith("decoder."):
                name = "model." + name
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)