# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
20
from typing import List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
21

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
22
23
24
25
import torch
from torch import nn
from transformers import OPTConfig

Woosuk Kwon's avatar
Woosuk Kwon committed
26
27
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
28
from vllm.model_executor.layers.attention import Attention
29
30
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Woosuk Kwon's avatar
Woosuk Kwon committed
35
from vllm.model_executor.layers.sampler import Sampler
36
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
Woosuk Kwon's avatar
Woosuk Kwon committed
38
from vllm.model_executor.parallel_utils.parallel_state import (
39
    get_tensor_model_parallel_world_size)
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
42
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
43
from vllm.sequence import SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
44
45
46

# One decoder layer's KV cache: a pair of tensors that the attention layer
# unpacks as ``(key_cache, value_cache)``.
KVCache = Tuple[torch.Tensor, torch.Tensor]

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
47
48
49
50

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embeddings using OPT's fixed index offset.

    The table is allocated with ``num_embeddings + 2`` rows, and position
    ``p`` is looked up at row ``p + 2``. This mirrors OPT's convention of
    reserving the first two rows of the position table.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack.
        shift = 2
        self.offset = shift
        table_rows = num_embeddings + shift
        super().__init__(table_rows, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Embed ``positions`` after shifting every id by ``self.offset``."""
        shifted_ids = positions + self.offset
        return super().forward(shifted_ids)


class OPTAttention(nn.Module):
    """OPT multi-head self-attention, with heads split across TP ranks.

    Q, K and V are produced by a single fused ``QKVParallelLinear``. Each
    tensor-parallel rank keeps ``num_heads // tp_world_size`` heads, and the
    output projection is a ``RowParallelLinear`` back to ``embed_dim``.

    Raises:
        ValueError: if ``num_heads`` is not divisible by the tensor-parallel
            world size, or ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        # Validate with explicit exceptions instead of `assert`: asserts are
        # stripped under `python -O`. Without these checks a bad config
        # would silently produce mis-sized shards or a floored head_dim.
        if total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_heads ({total_num_heads}) must be divisible by the "
                f"tensor parallel world size "
                f"({tensor_model_parallel_world_size}).")
        if embed_dim % total_num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads "
                f"({total_num_heads}).")
        # Heads held by this rank.
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            linear_method=linear_method,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: input activations for this layer.
            kv_cache: this layer's ``(key_cache, value_cache)`` pair.
            input_metadata: batch metadata forwarded to ``Attention``.

        Returns:
            The attention output after ``out_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Equal three-way split. Only total_num_heads is passed to
        # QKVParallelLinear, so q, k and v presumably share one head count.
        # NOTE(review): assumes the fused layout is [q | k | v]; confirm.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata)
        output, _ = self.out_proj(attn_output)
        return output

Woosuk Kwon's avatar
Woosuk Kwon committed
111

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
112
113
class OPTDecoderLayer(nn.Module):
    """A single OPT transformer block.

    The block is self-attention followed by a two-layer MLP
    (``fc1`` -> activation -> ``fc2``), each wrapped in a residual
    connection. ``config.do_layer_norm_before`` selects where each
    sub-layer's LayerNorm goes:

    * pre-norm (True): applied to the sub-layer input;
    * post-norm (False): applied to the residual sum.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config

        embed_dim = config.hidden_size
        ffn_dim = config.ffn_dim
        use_bias = config.enable_bias
        affine_norm = config.layer_norm_elementwise_affine
        self.embed_dim = embed_dim

        # --- attention sub-layer ---
        self.self_attn = OPTAttention(
            embed_dim=embed_dim,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            linear_method=linear_method,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.self_attn_layer_norm = nn.LayerNorm(
            embed_dim, elementwise_affine=affine_norm)

        # --- feed-forward sub-layer ---
        self.fc1 = ColumnParallelLinear(
            embed_dim,
            ffn_dim,
            bias=use_bias,
            linear_method=linear_method,
        )
        # The linear method's quant_config (None when absent) and the FFN
        # width are forwarded to get_act_fn.
        quant_config = getattr(linear_method, "quant_config", None)
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config, ffn_dim)
        self.fc2 = RowParallelLinear(
            ffn_dim,
            embed_dim,
            bias=use_bias,
            linear_method=linear_method,
        )
        self.final_layer_norm = nn.LayerNorm(
            embed_dim, elementwise_affine=affine_norm)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Apply attention then the MLP, each with a residual connection.

        Args:
            hidden_states: input activations for this layer.
            kv_cache: this layer's ``(key_cache, value_cache)`` pair.
            input_metadata: batch metadata forwarded to attention.

        Returns:
            The layer's output activations.
        """
        # 125m, 1.7B, ..., 175B use pre-norm; 350m uses post-norm.
        pre_norm = self.do_layer_norm_before

        # Attention sub-layer.
        attn_skip = hidden_states
        attn_in = (self.self_attn_layer_norm(hidden_states)
                   if pre_norm else hidden_states)
        attn_out = self.self_attn(hidden_states=attn_in,
                                  kv_cache=kv_cache,
                                  input_metadata=input_metadata)
        hidden_states = attn_skip + attn_out
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward sub-layer.
        ffn_skip = hidden_states
        ffn_in = (self.final_layer_norm(hidden_states)
                  if pre_norm else hidden_states)
        intermediate, _ = self.fc1(ffn_in)
        intermediate = self.activation_fn(intermediate)
        ffn_out, _ = self.fc2(intermediate)
        hidden_states = ffn_skip + ffn_out
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
186
class OPTDecoder(nn.Module):
    """OPT decoder stack.

    Contains the token and position embeddings, optional in/out projections
    between the embedding width and the hidden size, the stack of
    ``OPTDecoderLayer`` blocks, and an optional final LayerNorm.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        hidden_size = config.hidden_size
        proj_dim = config.word_embed_proj_dim

        # Token embeddings have width ``word_embed_proj_dim``.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, hidden_size)

        # When the token-embedding width differs from the hidden size, a
        # replicated linear maps embeddings up to hidden_size before the
        # first layer (project_in). A second one maps the final hidden
        # states back down afterwards (project_out).
        if proj_dim != hidden_size:
            self.project_out = ReplicatedLinear(hidden_size,
                                                proj_dim,
                                                bias=False,
                                                linear_method=linear_method)
            self.project_in = ReplicatedLinear(proj_dim,
                                               hidden_size,
                                               bias=False,
                                               linear_method=linear_method)
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        keep_final_norm = (config.do_layer_norm_before
                           and not config._remove_final_layer_norm)
        self.final_layer_norm = nn.LayerNorm(
            hidden_size,
            elementwise_affine=config.layer_norm_elementwise_affine,
        ) if keep_final_norm else None

        self.layers = nn.ModuleList(
            OPTDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed the inputs and run them through every decoder layer.

        Args:
            input_ids: token ids to embed.
            positions: position ids, offset by ``OPTLearnedPositionalEmbedding``.
            kv_caches: one ``(key_cache, value_cache)`` pair per layer;
                layer ``i`` receives ``kv_caches[i]``.
            input_metadata: batch metadata forwarded to each layer.

        Returns:
            Final hidden states. When ``project_out`` exists they are mapped
            to ``word_embed_proj_dim``.
        """
        token_embeds = self.embed_tokens(input_ids)
        position_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            token_embeds, _ = self.project_in(token_embeds)
        hidden_states = token_embeds + position_embeds

        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, kv_caches[layer_idx],
                                  input_metadata)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
264
class OPTModel(nn.Module):
    """Thin container that holds the ``OPTDecoder`` as ``self.decoder``.

    This gives parameters a ``...decoder.*`` naming level under this module.
    """

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        decoder = OPTDecoder(config, linear_method)
        self.decoder = decoder

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Delegate to the decoder and return its final hidden states."""
        decoder_output = self.decoder(input_ids, positions, kv_caches,
                                      input_metadata)
        return decoder_output
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
282
283


Zhuohan Li's avatar
Zhuohan Li committed
284
class OPTForCausalLM(nn.Module):
    """OPT decoder plus a language-modeling head for causal generation.

    The LM head is tied to the decoder's token embedding: ``lm_head_weight``
    is the same tensor as ``model.decoder.embed_tokens.weight``. Any
    ``lm_head.weight`` tensor in a checkpoint is therefore ignored on load.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = OPTModel(config, linear_method)
        # Tied output projection: reuse the token-embedding matrix.
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run the model and return final hidden states (not logits)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits using the tied embedding weight."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` with the model's ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace OPT checkpoint weights into this model.

        Checkpoint names starting with ``decoder.`` are prefixed with
        ``model.``. Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are
        written into their shard of the fused ``qkv_proj`` parameter through
        its ``weight_loader``. Every other tensor is loaded into the
        same-named parameter, using the parameter's ``weight_loader`` if it
        has one and ``default_weight_loader`` otherwise. Bias tensors with no
        matching parameter (extra GPTQ biases) are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        # Keep every parameter name, including names that alias a shared
        # (tied) tensor.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "lm_head.weight" in name:
                # Tied to embed_tokens; nothing separate to load.
                continue
            if name.startswith("decoder."):
                name = "model." + name

            # Resolve the target name once and stop at the first match.
            # Re-scanning a rewritten name is wrong because "v_proj" is a
            # substring of "qkv_proj" and would mangle it.
            shard_id = None
            for (param_name, weight_name,
                 mapped_shard_id) in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = mapped_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused QKV parameters must provide a shard-aware loader.
                param.weight_loader(param, loaded_weight, shard_id)