opt.py 13.3 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
7
8
9
10
11
12
13
14
15
16
17
18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
19
"""Inference-only OPT model compatible with HuggingFace weights."""
20
from typing import List, Optional
Woosuk Kwon's avatar
Woosuk Kwon committed
21

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
22
23
24
25
import torch
from torch import nn
from transformers import OPTConfig

26
from vllm.attention import Attention, AttentionMetadata
Woosuk Kwon's avatar
Woosuk Kwon committed
27
from vllm.model_executor.layers.activation import get_act_fn
28
29
30
31
32
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
33
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.model_executor.layers.sampler import Sampler
35
36
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
Woosuk Kwon's avatar
Woosuk Kwon committed
37
from vllm.model_executor.parallel_utils.parallel_state import (
38
    get_tensor_model_parallel_world_size)
39
from vllm.model_executor.sampling_metadata import SamplingMetadata
40
41
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
42
from vllm.sequence import SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
43

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
44
45
46
47

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embeddings with OPT's fixed index offset.

    OPT shifts every position id by 2 before the table lookup and enlarges
    the table by the same amount, so position ``p`` reads row ``p + 2``.
    Other models don't use this hack; it is kept for compatibility with
    OPT checkpoints.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        shift = 2
        # Allocate `shift` extra rows so the shifted ids stay in range.
        super().__init__(num_embeddings + shift, embedding_dim)
        self.offset = shift

    def forward(self, positions: torch.Tensor):
        """Look up embeddings for ``positions`` after applying the offset."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """OPT multi-head self-attention, sharded over tensor-parallel ranks.

    Q, K and V come from one fused ``QKVParallelLinear``. The attention
    output is projected back to ``embed_dim`` by a ``RowParallelLinear``.
    Each rank holds ``num_heads // tp_world_size`` heads.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim

        tp_size = get_tensor_model_parallel_world_size()
        total_num_heads = num_heads
        # Heads are split evenly across ranks; uneven splits are unsupported.
        assert num_heads % tp_size == 0
        self.num_heads = total_num_heads // tp_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            linear_method=linear_method,
        )
        # Per-rank attention over this rank's shard of heads.
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention and return the ``out_proj`` result.

        NOTE(review): assumes the last dim of ``hidden_states`` is
        ``embed_dim`` — confirm against callers.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # Q, K and V share the same per-rank width, so an equal 3-way split
        # along the last dim recovers them.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.out_proj(context)
        return projected

Woosuk Kwon's avatar
Woosuk Kwon committed
106

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
107
108
class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention followed by a 2-layer MLP.

    Each sub-block has a residual connection. ``config.do_layer_norm_before``
    selects pre-LayerNorm (norm applied to the sub-block input) or
    post-LayerNorm (norm applied after the residual add).
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        use_bias = config.enable_bias
        affine = config.layer_norm_elementwise_affine

        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            linear_method=linear_method,
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim,
                                                 elementwise_affine=affine)
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=use_bias,
            linear_method=linear_method,
        )
        # The quantization config (if the linear method carries one) is
        # forwarded to the activation factory.
        self.activation_fn = get_act_fn(
            config.activation_function,
            getattr(linear_method, "quant_config", None),
            config.ffn_dim,
        )
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=use_bias,
            linear_method=linear_method,
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim,
                                             elementwise_affine=affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-blocks in sequence."""
        pre_norm = self.do_layer_norm_before

        # --- Self-attention sub-block ---
        skip = hidden_states
        # 125m, 1.7B, ..., 175B apply layer norm BEFORE attention.
        if pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        attn_out = self.self_attn(hidden_states=hidden_states,
                                  kv_cache=kv_cache,
                                  attn_metadata=attn_metadata)
        hidden_states = skip + attn_out
        # 350m applies layer norm AFTER attention.
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # --- Feed-forward sub-block ---
        skip = hidden_states
        # Pre-LN models normalize BEFORE the feed-forward block.
        if pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = skip + hidden_states
        # Post-LN models (350m) normalize AFTER the feed-forward block.
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
181
class OPTDecoder(nn.Module):
    """OPT decoder: embeddings, a stack of decoder layers, optional final LN.

    Token embeddings have width ``config.word_embed_proj_dim``. When that
    differs from ``config.hidden_size``, replicated ``project_in`` /
    ``project_out`` linears convert between the two widths around the layer
    stack; otherwise both are ``None``.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in are replicated if they exist. Both are needed
        # under exactly the same condition, so one check decides both.
        # project_out is created first to keep the module registration
        # order unchanged.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                linear_method=linear_method)
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               linear_method=linear_method)
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList([
            OPTDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``/``positions`` and run every decoder layer.

        ``kv_caches[i]`` is handed to layer ``i``. The output has been
        through the optional final layer norm and ``project_out``.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds, _ = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

        # Index kv_caches (rather than zip) so a too-short list still raises
        # instead of silently skipping layers.
        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
259
class OPTModel(nn.Module):
    """Thin wrapper exposing the OPT decoder under a ``decoder`` attribute.

    The attribute name matches the ``decoder.`` prefix used in checkpoint
    parameter names.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.decoder = OPTDecoder(config, linear_method)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Delegate directly to the decoder and return its hidden states."""
        decoder_out = self.decoder(input_ids, positions, kv_caches,
                                   attn_metadata)
        return decoder_out
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
277
278


Zhuohan Li's avatar
Zhuohan Li committed
279
class OPTForCausalLM(nn.Module):
    """OPT with a language-modeling head tied to the token embeddings.

    The LM head reuses ``model.decoder.embed_tokens.weight``. For that
    reason, ``lm_head.weight`` entries in checkpoints are ignored when
    loading.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = OPTModel(config, linear_method)
        # Weight tying: logits are computed against the input embedding table.
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states (logits are separate)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the tied weight."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace OPT checkpoint tensors into this model.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are written into
        the matching shard of the fused ``qkv_proj`` parameter. Names
        starting with ``decoder.`` are prefixed with ``model.``.
        ``lm_head.weight`` is skipped because the head is tied. Bias tensors
        with no matching parameter are skipped (extra GPTQ biases). Any other
        unknown name raises ``KeyError``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "lm_head.weight" in name:
                continue
            if name.startswith("decoder."):
                name = "model." + name

            # Resolve the stacked mapping once, stopping at the first match.
            # Never re-scan a rewritten name: "qkv_proj" itself contains
            # "v_proj", so a second pass would mangle it. (Previously the
            # GPTQ-bias `continue` below sat inside this loop and did exactly
            # that, skipping the weight only via the for-else by accident.)
            shard_id = None
            for param_name, weight_name, stacked_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused parameters must provide a shard-aware loader.
                param.weight_loader(param, loaded_weight, shard_id)