# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import OPTConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput


class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embedding with OPT's fixed index offset.

    OPT is set up so that, when ``padding_idx`` is specified, embedding ids
    are offset by 2 (other models don't have this hack). The table is
    therefore allocated with ``num_embeddings + 2`` rows, and every position
    id is shifted by the same offset before lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # The offset is recorded first because the table size depends on it.
        self.offset = 2
        table_rows = num_embeddings + self.offset
        super().__init__(table_rows, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Look up embeddings for ``positions`` after applying the offset."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """Multi-head self-attention for one OPT decoder layer.

    Q/K/V come from a single fused projection (``qkv_proj``) and the result
    is projected back with ``out_proj``; the attention computation itself,
    including use of the KV cache, is delegated to vLLM's ``Attention``
    layer. Heads are split evenly across the tensor-parallel world, so this
    rank owns ``num_heads // world_size`` of them.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Total number of attention heads across all ranks; must be
            divisible by the tensor-parallel world size.
        bias: Whether the Q/K/V and output projections have a bias.
        cache_config: Forwarded to ``Attention``.
        quant_config: Forwarded to the projections and ``Attention``.

    Raises:
        ValueError: If either divisibility requirement above is violated.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        # Explicit exceptions instead of ``assert``: assertions are stripped
        # under ``python -O``, which would let an invalid configuration
        # through to the layers below.
        if embed_dim % total_num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads "
                f"({total_num_heads}).")
        if total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_heads ({total_num_heads}) must be divisible by the "
                f"tensor parallel world size "
                f"({tensor_model_parallel_world_size}).")
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, run attention, and project back.

        Args:
            hidden_states: Layer input.
            kv_cache: This layer's KV cache, passed through to ``Attention``.
            attn_metadata: Passed through to ``Attention``.

        Returns:
            The output of ``out_proj``.
        """
        # The parallel linear layers return a pair; only the first element
        # is used here.
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused output's last dimension holds Q, K and V in that order.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output


class OPTDecoderLayer(nn.Module):
    """A single OPT transformer block.

    The block is self-attention followed by a feed-forward network
    (``fc1`` -> activation -> ``fc2``), each sub-block wrapped in a residual
    connection. ``config.do_layer_norm_before`` chooses where each
    sub-block's LayerNorm is applied:

    * ``True`` (125m, 1.7B, ..., 175B): on the sub-block input (pre-LN).
    * ``False`` (350m): on the sub-block output, after the residual add
      (post-LN).
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size

        width = self.embed_dim
        inner = config.ffn_dim
        use_bias = config.enable_bias
        affine = config.layer_norm_elementwise_affine

        # Attention sub-block and its LayerNorm.
        self.self_attn = OPTAttention(
            embed_dim=width,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.self_attn_layer_norm = nn.LayerNorm(width,
                                                 elementwise_affine=affine)

        # Feed-forward sub-block: width -> inner -> width, plus its LayerNorm.
        self.fc1 = ColumnParallelLinear(width,
                                        inner,
                                        bias=use_bias,
                                        quant_config=quant_config)
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config, inner)
        self.fc2 = RowParallelLinear(inner,
                                     width,
                                     bias=use_bias,
                                     quant_config=quant_config)
        self.final_layer_norm = nn.LayerNorm(width, elementwise_affine=affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-blocks with residuals.

        Args:
            hidden_states: Layer input.
            kv_cache: This layer's KV cache, passed through to attention.
            attn_metadata: Passed through to attention.

        Returns:
            The layer output (same shape as the input, since each sub-block
            is added back onto its input).
        """
        pre_ln = self.do_layer_norm_before

        # Self-attention sub-block.
        shortcut = hidden_states
        x = (self.self_attn_layer_norm(hidden_states)
             if pre_ln else hidden_states)
        x = self.self_attn(hidden_states=x,
                           kv_cache=kv_cache,
                           attn_metadata=attn_metadata)
        x = shortcut + x
        if not pre_ln:
            x = self.self_attn_layer_norm(x)

        # Feed-forward sub-block (norm placement follows the same rule).
        shortcut = x
        if pre_ln:
            x = self.final_layer_norm(x)
        x, _ = self.fc1(x)
        x = self.activation_fn(x)
        x, _ = self.fc2(x)
        x = shortcut + x
        if not pre_ln:
            x = self.final_layer_norm(x)
        return x


class OPTDecoder(nn.Module):
    """OPT decoder stack.

    Holds the token and positional embeddings, the ``num_hidden_layers``
    decoder layers, an optional final LayerNorm, and, when the embedding
    width differs from ``hidden_size``, the ``project_in``/``project_out``
    projections that bridge the two widths.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        # Token embeddings have width word_embed_proj_dim, which may differ
        # from hidden_size.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist. They exist
        # together, exactly when the embedding width differs from
        # hidden_size, so one test decides both (project_out is still
        # registered first).
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                quant_config=quant_config)
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               quant_config=quant_config)
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        # Only pre-LN models get a final LayerNorm.
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList([
            OPTDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed the inputs and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Position ids for the positional embedding.
            kv_caches: One KV cache per decoder layer; ``kv_caches[i]`` is
                passed to ``self.layers[i]``.
            attn_metadata: Shared by all layers.

        Returns:
            Final hidden states of width ``word_embed_proj_dim`` (after
            ``project_out`` when it exists, otherwise that width already
            equals ``hidden_size``).

        Raises:
            IndexError: If fewer caches than layers are supplied.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds, _ = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

        # Indexing kv_caches (rather than zip) keeps the IndexError when too
        # few caches are given, instead of silently skipping layers.
        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


class OPTModel(nn.Module):
    """Wrapper that holds the OPT decoder under the ``decoder`` attribute.

    Nesting the decoder this way keeps its parameter names in the
    ``decoder.*`` form.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.decoder = OPTDecoder(config, cache_config, quant_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Delegate to the decoder and return its hidden states unchanged."""
        return self.decoder(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )


class OPTForCausalLM(nn.Module):
    """OPT decoder with a tied LM head, logits processing and sampling.

    The LM head reuses the token-embedding matrix
    (``model.decoder.embed_tokens.weight``), so checkpoint ``lm_head.weight``
    tensors are ignored when loading.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(config, cache_config, quant_config)
        # Tied output projection: logits are computed against the input
        # token-embedding weight.
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return final hidden states; logits come from compute_logits."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the tied weight."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` using the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs in checkpoint naming. Names
                starting with ``decoder.`` are re-rooted under ``model.``.
                Separate ``q_proj``/``k_proj``/``v_proj`` tensors are loaded
                as the ``q``/``k``/``v`` shards of the fused ``qkv_proj``.

        Raises:
            KeyError: If a non-skipped checkpoint name has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        # remove_duplicate=False keeps shared parameters reachable under
        # every name they are registered with.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # Tied to embed_tokens; nothing to load.
                continue
            if name.startswith("decoder."):
                name = "model." + name

            # Resolve the (at most one) q/k/v -> qkv_proj rename up front.
            # Previously a skipped bias kept being matched against its
            # already-rewritten name ("qkv_proj" contains "v_proj"), which
            # only worked because the fallback path repeated the skip check.
            shard_id = None
            for param_name, weight_name, stacked_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            if shard_id is not None:
                # Fused parameters must provide a shard-aware loader.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)