opt.py 12.9 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
7
8
9
10
11
12
13
14
15
16
17
18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
19
"""Inference-only OPT model compatible with HuggingFace weights."""
20
from typing import Iterable, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
21

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
22
23
24
25
import torch
from torch import nn
from transformers import OPTConfig

26
from vllm.attention import Attention, AttentionMetadata
27
from vllm.distributed import get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
28
from vllm.model_executor.layers.activation import get_act_fn
29
30
31
32
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
33
from vllm.model_executor.layers.logits_processor import LogitsProcessor
34
35
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.model_executor.layers.sampler import Sampler
37
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
42

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
43
44
45
46

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embedding table used by OPT.

    OPT shifts every position id by ``offset`` (2) before the lookup, so the
    underlying table is allocated with ``num_embeddings + offset`` rows.
    Other models don't use this shift.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that, when padding_idx is specified, embedding ids
        # are offset by 2 and num_embeddings is enlarged to match.
        shift = 2
        super().__init__(num_embeddings + shift, embedding_dim)
        self.offset = shift

    def forward(self, positions: torch.Tensor):
        """Look up the embeddings for ``positions`` after applying the shift."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """OPT multi-head self-attention, sharded over tensor-parallel ranks.

    Q, K and V are produced by one fused ``QKVParallelLinear`` projection.
    Each rank keeps ``num_heads // tp_world_size`` heads, and the result is
    projected back to ``embed_dim`` by a ``RowParallelLinear``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the projections and the attention kernel wrapper.

        Args:
            embed_dim: Model hidden size.
            num_heads: Total number of attention heads across all ranks.
            bias: Whether the QKV and output projections carry a bias.
            quant_config: Optional quantization config forwarded to the
                linear layers.

        Raises:
            ValueError: If ``num_heads`` is not divisible by the tensor
                parallel world size, or ``embed_dim`` is not divisible by
                ``num_heads``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        # Explicit checks instead of `assert` (stripped under `python -O`).
        # Without them the floor divisions below would silently truncate.
        if total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_heads ({total_num_heads}) must be divisible by the "
                f"tensor parallel world size "
                f"({tensor_model_parallel_world_size}).")
        if embed_dim % total_num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads "
                f"({total_num_heads}).")
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Input activations.
            kv_cache: KV cache tensor passed through to ``self.attn``.
            attn_metadata: Attention metadata passed through to ``self.attn``.

        Returns:
            The output of ``out_proj`` (its second return value is dropped).
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # NOTE(review): an equal three-way split assumes QKVParallelLinear
        # defaults the KV head count to num_heads (no GQA here) — confirm.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output

Woosuk Kwon's avatar
Woosuk Kwon committed
105

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
106
107
class OPTDecoderLayer(nn.Module):
    """A single OPT transformer block: self-attention, then a two-layer MLP.

    Each sub-block is wrapped in a residual connection. When
    ``config.do_layer_norm_before`` is true, LayerNorm runs before the
    sub-block (125m, 1.7B, ..., 175B). Otherwise it runs after the residual
    add (350m).
    """

    def __init__(
        self,
        config: OPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        width = config.hidden_size
        use_bias = config.enable_bias
        affine_norm = config.layer_norm_elementwise_affine
        self.embed_dim = width

        self.self_attn = OPTAttention(
            embed_dim=width,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.self_attn_layer_norm = nn.LayerNorm(
            width, elementwise_affine=affine_norm)

        self.fc1 = ColumnParallelLinear(width,
                                        config.ffn_dim,
                                        bias=use_bias,
                                        quant_config=quant_config)
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config, config.ffn_dim)
        self.fc2 = RowParallelLinear(config.ffn_dim,
                                     width,
                                     bias=use_bias,
                                     quant_config=quant_config)
        self.final_layer_norm = nn.LayerNorm(
            width, elementwise_affine=affine_norm)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and MLP sub-blocks, each with a residual add."""
        pre_norm = self.do_layer_norm_before

        # Attention sub-block.
        skip = hidden_states
        x = (self.self_attn_layer_norm(hidden_states)
             if pre_norm else hidden_states)
        x = self.self_attn(hidden_states=x,
                           kv_cache=kv_cache,
                           attn_metadata=attn_metadata)
        x = skip + x
        if not pre_norm:
            x = self.self_attn_layer_norm(x)

        # Feed-forward sub-block.
        skip = x
        if pre_norm:
            x = self.final_layer_norm(x)
        x, _ = self.fc1(x)
        x = self.activation_fn(x)
        x, _ = self.fc2(x)
        x = skip + x
        if not pre_norm:
            x = self.final_layer_norm(x)
        return x


Zhuohan Li's avatar
Zhuohan Li committed
179
class OPTDecoder(nn.Module):
    """OPT embeddings, decoder-layer stack, and optional output stages.

    Token embeddings have width ``word_embed_proj_dim``. When that differs
    from ``hidden_size``, replicated linear layers project into the hidden
    size before the layers (``project_in``) and back out after them
    (``project_out``).
    """

    def __init__(
        self,
        config: OPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in are replicated and only exist when the embedding
        # width differs from the hidden size.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                quant_config=quant_config)
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               quant_config=quant_config)
        else:
            self.project_out = None
            self.project_in = None

        # `config._remove_final_layer_norm` exists only for backward
        # compatibility with checkpoints fine-tuned before transformers
        # v4.20.1; see https://github.com/facebookresearch/metaseq/pull/164
        keep_final_norm = (config.do_layer_norm_before
                           and not config._remove_final_layer_norm)
        self.final_layer_norm = (nn.LayerNorm(
            config.hidden_size,
            elementwise_affine=config.layer_norm_elementwise_affine)
                                 if keep_final_norm else None)

        self.layers = nn.ModuleList(
            OPTDecoderLayer(config, quant_config)
            for _ in range(config.num_hidden_layers))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed the inputs, run every decoder layer, and finalize.

        Args:
            input_ids: Token ids to embed.
            positions: Position ids for the learned positional embedding.
            kv_caches: One KV cache per layer, indexed by layer number.
            attn_metadata: Attention metadata shared by all layers.

        Returns:
            Final hidden states, after the optional LayerNorm and
            ``project_out``.
        """
        token_embeds = self.embed_tokens(input_ids)
        position_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            token_embeds, _ = self.project_in(token_embeds)
        hidden_states = token_embeds + position_embeds

        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, kv_caches[layer_idx],
                                  attn_metadata)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
257
class OPTModel(nn.Module):
    """Wrapper that nests the OPT decoder under the attribute ``decoder``.

    It adds no computation of its own; ``forward`` delegates to the decoder.
    """

    def __init__(self,
                 config: OPTConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.decoder = OPTDecoder(config, quant_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states for the given inputs."""
        decoder = self.decoder
        return decoder(input_ids, positions, kv_caches, attn_metadata)
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
275
276


Zhuohan Li's avatar
Zhuohan Li committed
277
class OPTForCausalLM(nn.Module):
    """OPT decoder plus a language-model head tied to the token embedding."""

    def __init__(
        self,
        config: OPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(config, quant_config)
        # The LM head reuses the token-embedding weight, which is why
        # load_weights skips any "lm_head.weight" tensor.
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the final hidden states produced by ``self.model``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from hidden states via the tied embedding weight."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are loaded as shards
        of the fused ``qkv_proj`` parameter. All other tensors are loaded
        into the parameter with the same (possibly re-prefixed) name.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from a checkpoint.

        Raises:
            KeyError: If a non-bias tensor has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            # Tied to embed_tokens; there is no separate LM head to load.
            if "lm_head.weight" in name:
                continue
            # Re-root "decoder.*" names under this module's "model." prefix.
            if name.startswith("decoder."):
                name = "model." + name

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. Use `break`, not
                # `continue`: `name` has already been rewritten, and
                # "v_proj" is a substring of "qkv_proj", so continuing would
                # re-match the rewritten name against later mapping entries.
                # `break` also bypasses the `else` clause below.
                if name.endswith(".bias") and name not in params_dict:
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)