# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
20
from typing import Iterable, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
21

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
22
23
24
25
import torch
from torch import nn
from transformers import OPTConfig

26
from vllm.attention import Attention, AttentionMetadata
27
from vllm.config import CacheConfig
28
from vllm.distributed import get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
29
from vllm.model_executor.layers.activation import get_act_fn
30
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
36
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
Woosuk Kwon's avatar
Woosuk Kwon committed
37
from vllm.model_executor.layers.sampler import Sampler
38
39
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
41
from vllm.model_executor.sampling_metadata import SamplingMetadata
42
from vllm.sequence import SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
43

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
44
45
46
47

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embedding table used by OPT.

    Every position id is shifted by ``self.offset`` (always 2) before the
    lookup. The underlying ``nn.Embedding`` is allocated with ``offset``
    extra rows so that the shifted ids stay in range.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # The table gets two extra leading rows that non-negative position
        # ids never reach. In the HuggingFace original this hack is tied to
        # padding_idx; here it is applied unconditionally (presumably to
        # match the layout of OPT checkpoints).
        shift = 2
        self.offset = shift
        super().__init__(num_embeddings + shift, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Return the embeddings for ``positions`` after applying the offset."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """Multi-head self-attention used inside an OPT decoder layer.

    Q, K and V come from a single fused tensor-parallel projection
    (``qkv_proj``). The attention output is mapped back to ``embed_dim``
    by a row-parallel projection (``out_proj``). Each tensor-parallel rank
    holds ``num_heads // tp_world_size`` heads.

    Args:
        embed_dim: Model hidden size.
        num_heads: Total number of attention heads across all TP ranks.
        bias: Whether the QKV and output projections carry a bias.
        cache_config: KV-cache configuration forwarded to ``Attention``.
        quant_config: Quantization configuration forwarded to the linear
            layers.

    Raises:
        ValueError: If ``num_heads`` is not divisible by the tensor-parallel
            world size, or ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        # Explicit checks rather than `assert`, which is stripped under -O.
        if total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_heads ({total_num_heads}) must be divisible by the "
                f"tensor parallel world size "
                f"({tensor_model_parallel_world_size}).")
        # Without this check a non-divisible embed_dim would only surface
        # later as a shape mismatch in out_proj or during weight loading.
        if embed_dim % total_num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads "
                f"({total_num_heads}).")
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Returns the output of ``out_proj``; its bias output is discarded.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # qkv_proj is built without a separate KV head count, so presumably
        # q, k and v have equal widths and an even 3-way split is correct.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output

Woosuk Kwon's avatar
Woosuk Kwon committed
108

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
109
110
class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention followed by a 2-layer MLP.

    Both sub-blocks are wrapped in a residual connection. When
    ``config.do_layer_norm_before`` is true, each sub-block's LayerNorm is
    applied to its input (pre-norm). Otherwise it is applied to the
    residual sum (post-norm).
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        use_bias = config.enable_bias
        affine = config.layer_norm_elementwise_affine

        self.embed_dim = hidden
        self.self_attn = OPTAttention(
            embed_dim=hidden,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.self_attn_layer_norm = nn.LayerNorm(hidden,
                                                 elementwise_affine=affine)

        # MLP: hidden -> ffn_dim (column-parallel), activation,
        # ffn_dim -> hidden (row-parallel).
        self.fc1 = ColumnParallelLinear(
            hidden,
            config.ffn_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config, config.ffn_dim)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            hidden,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.final_layer_norm = nn.LayerNorm(hidden,
                                             elementwise_affine=affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks and return the result."""
        pre_norm = self.do_layer_norm_before

        # Attention sub-block. Pre-norm checkpoints (e.g. 125m, 1.7B, ...,
        # 175B) normalize the input; post-norm ones (350m) normalize the sum.
        skip = hidden_states
        if pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        attn_out = self.self_attn(hidden_states=hidden_states,
                                  kv_cache=kv_cache,
                                  attn_metadata=attn_metadata)
        hidden_states = skip + attn_out
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # MLP sub-block, with the same pre-/post-norm choice as above.
        skip = hidden_states
        if pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        ffn_hidden, _ = self.fc1(hidden_states)
        ffn_hidden = self.activation_fn(ffn_hidden)
        ffn_out, _ = self.fc2(ffn_hidden)
        hidden_states = skip + ffn_out
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
184
class OPTDecoder(nn.Module):
    """Token and position embeddings followed by the stack of OPT layers.

    When ``config.word_embed_proj_dim`` differs from ``config.hidden_size``:
    - token embeddings are projected in (``project_in``) before being added
      to the position embeddings;
    - final hidden states are projected back out (``project_out``) after the
      optional final LayerNorm.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # The in/out projections are replicated and only exist when the word
        # embedding width differs from the hidden size; otherwise they are
        # None.
        needs_projection = config.word_embed_proj_dim != config.hidden_size
        self.project_out = (ReplicatedLinear(config.hidden_size,
                                             config.word_embed_proj_dim,
                                             bias=False,
                                             quant_config=quant_config)
                            if needs_projection else None)
        self.project_in = (ReplicatedLinear(config.word_embed_proj_dim,
                                            config.hidden_size,
                                            bias=False,
                                            quant_config=quant_config)
                           if needs_projection else None)

        # `config._remove_final_layer_norm` only exists for backward
        # compatibility with checkpoints fine-tuned before transformers
        # v4.20.1; see https://github.com/facebookresearch/metaseq/pull/164
        keep_final_norm = (config.do_layer_norm_before
                           and not config._remove_final_layer_norm)
        self.final_layer_norm = (nn.LayerNorm(
            config.hidden_size,
            elementwise_affine=config.layer_norm_elementwise_affine)
                                 if keep_final_norm else None)

        self.layers = nn.ModuleList(
            OPTDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed the inputs, run every layer, and return the hidden states.

        ``kv_caches[i]`` is handed to ``self.layers[i]``. Layer i reads
        ``kv_caches[i]``, so passing fewer caches than layers raises
        ``IndexError``.
        """
        hidden_states = self.embed_tokens(input_ids)
        if self.project_in is not None:
            hidden_states, _ = self.project_in(hidden_states)
        hidden_states = hidden_states + self.embed_positions(positions)

        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, kv_caches[layer_idx],
                                  attn_metadata)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
263
class OPTModel(nn.Module):
    """Wrapper that registers the OPT decoder under the ``decoder`` name.

    The extra nesting means parameters are named ``decoder.*``, which
    ``OPTForCausalLM`` exposes as ``model.decoder.*``.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        decoder = OPTDecoder(config, cache_config, quant_config)
        self.decoder = decoder

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Delegate to the decoder and return its final hidden states."""
        hidden_states = self.decoder(input_ids, positions, kv_caches,
                                     attn_metadata)
        return hidden_states
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
282
283


Zhuohan Li's avatar
Zhuohan Li committed
284
class OPTForCausalLM(nn.Module):
    """OPT decoder with a language-modeling head tied to the token embeddings.

    The LM head reuses ``model.decoder.embed_tokens.weight``, so no separate
    ``lm_head`` parameter exists. Any ``lm_head.weight`` in a checkpoint is
    ignored by ``load_weights``.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(config, cache_config, quant_config)
        # Tied LM head: reuse the token embedding matrix.
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states (not logits)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the tied weight."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HuggingFace-format OPT weights into this model.

        Separate ``q_proj``/``k_proj``/``v_proj`` checkpoint tensors are
        loaded as shards of the fused ``qkv_proj`` parameter. Checkpoint
        names starting with ``decoder.`` are prefixed with ``model.``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            # The LM head is tied to embed_tokens; nothing to load.
            if "lm_head.weight" in name:
                continue
            if name.startswith("decoder."):
                name = "model." + name

            # Resolve the stacked mapping exactly once. Previously the GPTQ
            # bias skip used `continue` inside this loop, which only moved on
            # to the next mapping entry and let it re-match the renamed name
            # ("qkv_proj" contains "v_proj").
            shard_id = None
            for param_name, weight_name, stacked_shard in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, shard_id)