# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
20
from typing import List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
21

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
22
23
24
25
import torch
from torch import nn
from transformers import OPTConfig

Woosuk Kwon's avatar
Woosuk Kwon committed
26
27
28
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
29
30
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.model_executor.layers.sampler import Sampler
35
36
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
Woosuk Kwon's avatar
Woosuk Kwon committed
37
from vllm.model_executor.parallel_utils.parallel_state import (
38
    get_tensor_model_parallel_world_size)
39
from vllm.model_executor.sampling_metadata import SamplingMetadata
40
41
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
42
from vllm.sequence import SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
43
44
45

# Per-layer KV cache pair; OPTAttention.forward unpacks it as
# (key_cache, value_cache) and hands both tensors to PagedAttention.
KVCache = Tuple[torch.Tensor, torch.Tensor]

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
46
47
48
49

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embedding with OPT's fixed index offset.

    OPT is set up so that, when a padding index is specified, embedding ids
    are offset by 2 (other models don't have this hack). The table is
    therefore allocated with two extra rows, and every position id is
    shifted by that offset before the lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        reserved_rows = 2
        super().__init__(num_embeddings + reserved_rows, embedding_dim)
        # Remember the shift so forward() applies the same offset.
        self.offset = reserved_rows

    def forward(self, positions: torch.Tensor):
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """Multi-head self-attention for OPT, sharded over tensor-parallel ranks.

    Q, K and V come from one fused ``QKVParallelLinear``. Attention is
    computed by ``PagedAttention`` against the per-layer KV cache, and the
    result is projected back with a ``RowParallelLinear``.

    Args:
        embed_dim: Model hidden size.
        num_heads: Total number of attention heads (before sharding).
        bias: Whether the QKV and output projections carry a bias.
        linear_method: Optional ``LinearMethodBase`` forwarded to the
            projection layers.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``, or
            ``num_heads`` is not divisible by the tensor-parallel world size.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        # Real exceptions instead of ``assert`` so the checks survive
        # ``python -O``. If embed_dim were not divisible by num_heads, the
        # attention output (num_heads * head_dim) would be narrower than the
        # embed_dim that out_proj expects.
        if embed_dim % total_num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads "
                f"({total_num_heads}).")
        if total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_heads ({total_num_heads}) must be divisible by the "
                f"tensor parallel world size "
                f"({tensor_model_parallel_world_size}).")
        # Heads handled by this rank.
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            linear_method=linear_method,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Input activations for this layer.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Attention metadata passed to ``PagedAttention``.

        Returns:
            The output of ``out_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # qkv_proj is built with no separate KV-head count, so the fused
        # output is split into three equal parts for q, k and v.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata)
        output, _ = self.out_proj(attn_output)
        return output

Woosuk Kwon's avatar
Woosuk Kwon committed
110

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
111
112
class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention followed by an MLP.

    Each sub-block sits inside a residual connection. When
    ``config.do_layer_norm_before`` is true (125m, 1.7B, ..., 175B), the
    LayerNorm is applied to the sub-block input. Otherwise (350m), it is
    applied to the residual sum.

    Args:
        config: HuggingFace ``OPTConfig`` describing the model.
        linear_method: Optional ``LinearMethodBase`` forwarded to every
            linear layer.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        width = self.embed_dim
        use_bias = config.enable_bias

        self.self_attn = OPTAttention(
            embed_dim=width,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            linear_method=linear_method,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.self_attn_layer_norm = nn.LayerNorm(
            width, elementwise_affine=config.layer_norm_elementwise_affine)

        self.fc1 = ColumnParallelLinear(
            width,
            config.ffn_dim,
            bias=use_bias,
            linear_method=linear_method,
        )
        # The activation factory receives the linear method's quant_config
        # when it has one (None otherwise).
        self.activation_fn = get_act_fn(
            config.activation_function,
            getattr(linear_method, "quant_config", None),
            config.ffn_dim,
        )
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            width,
            bias=use_bias,
            linear_method=linear_method,
        )
        self.final_layer_norm = nn.LayerNorm(
            width, elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-blocks with residual connections.

        Args:
            hidden_states: Input activations for this layer.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Attention metadata forwarded to the attention.

        Returns:
            The layer's output activations.
        """
        pre_norm = self.do_layer_norm_before

        # Self-attention sub-block.
        skip = hidden_states
        attn_in = (self.self_attn_layer_norm(hidden_states)
                   if pre_norm else hidden_states)
        attn_out = self.self_attn(hidden_states=attn_in,
                                  kv_cache=kv_cache,
                                  input_metadata=input_metadata)
        hidden_states = skip + attn_out
        if not pre_norm:
            # Post-LN variant: normalize after the residual add.
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward (MLP) sub-block.
        skip = hidden_states
        mlp_in = (self.final_layer_norm(hidden_states)
                  if pre_norm else hidden_states)
        mlp_hidden, _ = self.fc1(mlp_in)
        mlp_hidden = self.activation_fn(mlp_hidden)
        mlp_out, _ = self.fc2(mlp_hidden)
        hidden_states = skip + mlp_out
        if not pre_norm:
            # Post-LN variant: normalize after the residual add.
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
185
class OPTDecoder(nn.Module):
    """OPT decoder: embeddings, the decoder-layer stack and post-processing.

    Token embeddings have width ``config.word_embed_proj_dim``. When that
    differs from ``config.hidden_size``, they are projected in to the hidden
    size before the learned position embeddings are added, and the final
    hidden states are projected back out at the end.

    Args:
        config: HuggingFace ``OPTConfig`` describing the model.
        linear_method: Optional ``LinearMethodBase`` forwarded to every
            linear layer.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in are replicated and exist only when the token
        # embedding width differs from the hidden size. One branch decides
        # both; project_out is still registered first, as before.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                linear_method=linear_method)
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               linear_method=linear_method)
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been
        # fine-tuned before transformers v4.20.1.
        # See https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList([
            OPTDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed the inputs and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Position ids, looked up in the learned position table.
            kv_caches: One ``KVCache`` per decoder layer, in layer order.
            input_metadata: Attention metadata passed to each layer.

        Returns:
            Final hidden states, after the optional final LayerNorm and
            ``project_out``.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds, _ = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
263
class OPTModel(nn.Module):
    """Container that exposes an ``OPTDecoder`` as ``self.decoder``.

    It performs no computation of its own; ``forward`` delegates straight to
    the decoder.

    Args:
        config: HuggingFace ``OPTConfig`` describing the model.
        linear_method: Optional ``LinearMethodBase`` forwarded to the decoder.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        decoder = OPTDecoder(config, linear_method)
        self.decoder = decoder

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run the decoder and return its final hidden states."""
        hidden_states = self.decoder(input_ids, positions, kv_caches,
                                     input_metadata)
        return hidden_states
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
281
282


Zhuohan Li's avatar
Zhuohan Li committed
283
class OPTForCausalLM(nn.Module):
    """OPT causal language model for inference.

    Combines the ``OPTModel`` backbone with a ``Sampler``. The LM head has
    no weight of its own: it reuses the token-embedding matrix
    (``model.decoder.embed_tokens.weight``).

    Args:
        config: HuggingFace ``OPTConfig`` describing the model.
        linear_method: Optional ``LinearMethodBase`` forwarded to the
            backbone's linear layers.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = OPTModel(config, linear_method)
        # LM head tied to the token embedding; a checkpoint's
        # "lm_head.weight" is therefore skipped in load_weights.
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run the backbone and return its final hidden states.

        The LM-head projection is not applied here; see ``sample``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Hand the tied LM-head weight, hidden states and sampling metadata
        to ``self.sampler`` and return its result unchanged."""
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace-format OPT weights into this model.

        Checkpoint names are mapped onto this module's parameters as follows:

        * ``decoder.*`` gets the ``model.`` prefix.
        * Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are written
          as shards of the fused ``qkv_proj`` via its ``weight_loader``.
        * ``lm_head.weight`` is skipped because the LM head is tied to the
          token embedding.
        * Biases with no matching parameter (extra GPTQ biases) are skipped.

        Args:
            model_name_or_path: Passed to ``hf_model_weights_iterator``.
            cache_dir: Passed to ``hf_model_weights_iterator``.
            load_format: Passed to ``hf_model_weights_iterator``.
            revision: Passed to ``hf_model_weights_iterator``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "lm_head.weight" in name:
                continue
            if name.startswith("decoder."):
                name = "model." + name

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. Use ``break``, not
                # ``continue``: the name is already rewritten, and
                # "qkv_proj" itself contains "v_proj", so continuing would
                # re-match and mangle it against the remaining entries.
                if name.endswith(".bias") and name not in params_dict:
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked (q/k/v) weight: load it directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)