opt.py 13.2 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
7
8
9
10
11
12
13
14
15
16
17
18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
19
"""Inference-only OPT model compatible with HuggingFace weights."""
20
from typing import List, Optional
Woosuk Kwon's avatar
Woosuk Kwon committed
21

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
22
23
24
25
import torch
from torch import nn
from transformers import OPTConfig

26
from vllm.attention import Attention, AttentionMetadata
27
from vllm.distributed import get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
28
from vllm.model_executor.layers.activation import get_act_fn
29
30
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Woosuk Kwon's avatar
Woosuk Kwon committed
35
from vllm.model_executor.layers.sampler import Sampler
36
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
38
from vllm.model_executor.sampling_metadata import SamplingMetadata
39
40
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
41
from vllm.sequence import SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
42

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
43
44
45
46

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embedding with OPT's fixed index offset.

    The lookup table is allocated with ``num_embeddings + 2`` rows and every
    incoming position id is shifted by 2 before the lookup, so position 0
    reads row 2 of the table.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack.
        shift = 2
        self.offset = shift
        super().__init__(num_embeddings + shift, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Look up embeddings for ``positions`` after applying the offset."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """OPT multi-head self-attention, sharded over tensor-parallel ranks.

    Q, K and V are produced by one fused ``QKVParallelLinear`` projection and
    split into three equal chunks. Each rank holds
    ``num_heads // tensor_parallel_world_size`` heads of size
    ``embed_dim // num_heads``. The attention output goes through a
    ``RowParallelLinear`` output projection.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the fused QKV projection, attention op and output projection.

        Args:
            embed_dim: Model hidden size.
            num_heads: Total number of attention heads across all ranks.
            bias: Whether the linear projections carry a bias.
            linear_method: Optional (e.g. quantized) linear implementation
                forwarded to the parallel linear layers.

        Raises:
            ValueError: if ``num_heads`` is not divisible by the tensor
                parallel world size, or ``embed_dim`` is not divisible by
                ``num_heads``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        # Explicit checks instead of `assert`: asserts are stripped under
        # `python -O`, and the integer divisions below would then silently
        # truncate to a wrong head count / head size.
        if total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_heads ({total_num_heads}) must be divisible by the "
                f"tensor parallel world size "
                f"({tensor_model_parallel_world_size}).")
        if embed_dim % total_num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads "
                f"({total_num_heads}).")
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            linear_method=linear_method,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention on ``hidden_states`` and project the result.

        ``kv_cache`` and ``attn_metadata`` are passed straight through to the
        ``Attention`` op.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection output is split into three equal parts along
        # the last dim (Q, K, V).
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output

Woosuk Kwon's avatar
Woosuk Kwon committed
105

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
106
107
class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention followed by a 2-layer MLP.

    Each sub-block is wrapped in a residual connection. The
    ``config.do_layer_norm_before`` flag controls the LayerNorm placement.
    When true, the norm is applied to the sub-block input (pre-norm). When
    false, it is applied to the residual sum (post-norm).
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        use_bias = config.enable_bias
        affine_norm = config.layer_norm_elementwise_affine
        self.embed_dim = hidden_size

        # NOTE: submodule attribute names determine parameter names used by
        # checkpoint loading; they must not change.
        self.self_attn = OPTAttention(
            embed_dim=hidden_size,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            linear_method=linear_method,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.self_attn_layer_norm = nn.LayerNorm(
            hidden_size, elementwise_affine=affine_norm)

        self.fc1 = ColumnParallelLinear(
            hidden_size,
            config.ffn_dim,
            bias=use_bias,
            linear_method=linear_method,
        )
        act_quant_config = getattr(linear_method, "quant_config", None)
        self.activation_fn = get_act_fn(config.activation_function,
                                        act_quant_config, config.ffn_dim)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            hidden_size,
            bias=use_bias,
            linear_method=linear_method,
        )
        self.final_layer_norm = nn.LayerNorm(
            hidden_size, elementwise_affine=affine_norm)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks with residual connections."""
        pre_norm = self.do_layer_norm_before

        # Self-attention sub-block. 125m, 1.7B, ..., 175B normalize first
        # (pre-norm); 350m normalizes after the residual add (post-norm).
        shortcut = hidden_states
        x = (self.self_attn_layer_norm(hidden_states)
             if pre_norm else hidden_states)
        x = self.self_attn(hidden_states=x,
                           kv_cache=kv_cache,
                           attn_metadata=attn_metadata)
        x = shortcut + x
        if not pre_norm:
            x = self.self_attn_layer_norm(x)

        # Feed-forward sub-block, same pre-/post-norm placement as above.
        shortcut = x
        if pre_norm:
            x = self.final_layer_norm(x)
        x, _ = self.fc1(x)
        x = self.activation_fn(x)
        x, _ = self.fc2(x)
        x = shortcut + x
        if not pre_norm:
            x = self.final_layer_norm(x)
        return x


Zhuohan Li's avatar
Zhuohan Li committed
180
class OPTDecoder(nn.Module):
    """OPT decoder stack: embeddings, ``num_hidden_layers`` blocks, final norm.

    Token embeddings have width ``config.word_embed_proj_dim``. When that
    differs from ``config.hidden_size``, replicated ``project_in`` and
    ``project_out`` linears map into and out of the hidden size around the
    layer stack.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in are replicated and exist only when the embedding
        # width differs from the hidden size. One test covers both; the
        # registration order (project_out, then project_in) is preserved.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                linear_method=linear_method)
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               linear_method=linear_method)
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList([
            OPTDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``/``positions`` and run every decoder layer.

        ``kv_caches[i]`` is the cache passed to layer ``i``. It is indexed
        (not zipped), so a list shorter than the layer stack raises
        ``IndexError`` instead of silently skipping layers.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds, _ = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
258
class OPTModel(nn.Module):
    """Container holding the OPT decoder under the ``decoder`` attribute.

    Keeping the decoder one level down gives parameter names of the form
    ``decoder.*``. Under ``OPTForCausalLM.model`` these become
    ``model.decoder.*``.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        decoder = OPTDecoder(config, linear_method)
        self.decoder = decoder

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Delegate to the wrapped decoder and return its hidden states."""
        hidden_states = self.decoder(input_ids, positions, kv_caches,
                                     attn_metadata)
        return hidden_states
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
276
277


Zhuohan Li's avatar
Zhuohan Li committed
278
class OPTForCausalLM(nn.Module):
    """OPT decoder plus a language-modelling head tied to the embeddings.

    The LM head reuses ``model.decoder.embed_tokens.weight``; there is no
    separate head weight, which is why ``load_weights`` skips checkpoint
    entries named ``lm_head.weight``.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = OPTModel(config, linear_method)
        # Tie the LM head to the input token-embedding matrix (same tensor).
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states (logits are computed
        separately by ``compute_logits``)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project ``hidden_states`` onto the vocabulary via the tied head."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Draw next tokens from ``logits`` using the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace OPT checkpoint tensors into this model.

        All arguments are forwarded unchanged to
        ``hf_model_weights_iterator``. Separate ``q_proj``/``k_proj``/``v_proj``
        checkpoint tensors are loaded as shards of the fused ``qkv_proj``
        parameter.

        Raises:
            KeyError: if a checkpoint tensor (other than a skipped extra bias
                or ``lm_head.weight``) has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        # remove_duplicate=False keeps every name under which a shared
        # parameter is reachable.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "lm_head.weight" in name:
                # Tied to the token embeddings; nothing separate to load.
                continue
            if name.startswith("decoder."):
                # Checkpoints saved without the top-level "model." prefix.
                name = "model." + name

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. `break` (not
                # `continue`) ends the search: continuing would keep matching
                # the already-rewritten name ("qkv_proj" contains "v_proj")
                # and rewrite it a second time.
                if name.endswith(".bias") and name not in params_dict:
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)