# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import OPTConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class OPTLearnedPositionalEmbedding(nn.Embedding):

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor):
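        # A minimal sketch of the offset behavior (illustrative only): with
        # self.offset == 2, position ids [0, 1, 2] look up rows [2, 3, 4] of
        # the table, which is why __init__ allocates num_embeddings + 2 rows.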
        return super().forward(positions + self.offset)


class OPTAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5
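        # Sharding sketch (illustrative numbers, not taken from any config):
        # with embed_dim=768, num_heads=12 and a tensor-parallel world size of
        # 4, each rank holds 12 // 4 = 3 heads of head_dim 768 // 12 = 64, and
        # the fused QKV and output projections below are partitioned to match.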

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            linear_method=linear_method,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
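        # Shape sketch (illustrative): per tensor-parallel rank,
        #   hidden_states: [num_tokens, embed_dim]
        #   qkv:           [num_tokens, 3 * self.num_heads * self.head_dim]
        # so the three equal chunks below are this rank's query, key and value
        # projections.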
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata, cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class OPTDecoderLayer(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
            linear_method=linear_method,
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=config.enable_bias,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config, config.ffn_dim)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=config.enable_bias,
            linear_method=linear_method,
        )
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
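        # Ordering sketch (illustrative): with do_layer_norm_before=True (the
        # pre-norm variant used by most OPT sizes) this block computes
        #   x -> LN -> attn -> +residual -> LN -> fc1 -> act -> fc2 -> +residual
        # while do_layer_norm_before=False (the post-norm 350m variant) gives
        #   x -> attn -> +residual -> LN -> fc1 -> act -> fc2 -> +residual -> LN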
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       input_metadata=input_metadata,
                                       cache_event=cache_event)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class OPTDecoder(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist.
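        # They are only needed when the embedding width differs from the
        # hidden size (e.g. opt-350m pairs word_embed_proj_dim=512 with
        # hidden_size=1024); otherwise both projections are skipped.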
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                linear_method=linear_method)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               linear_method=linear_method)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList([
            OPTDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds, _ = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

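        # Each decoder layer uses its own (key_cache, value_cache) pair, and
        # cache_events, when provided, are CUDA events the attention waits on
        # so that asynchronous cache operations finish before the layer runs.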
        for i in range(len(self.layers)):
            cache_event = None if cache_events is None else cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
                                  cache_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


class OPTModel(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.decoder = OPTDecoder(config, linear_method)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        return self.decoder(input_ids, positions, kv_caches, input_metadata,
                            cache_events)


class OPTForCausalLM(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = OPTModel(config, linear_method)
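        # OPT ties the output projection to the input embeddings, so the
        # decoder's embedding matrix is reused as the LM head weight below.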
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
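        # The sampler projects hidden states onto the vocabulary using the
        # tied embedding weight and picks next tokens per sampling_metadata.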
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
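        # Mapping sketch (illustrative): a checkpoint weight named, e.g.,
        # "model.decoder.layers.0.self_attn.q_proj.weight" is renamed to the
        # fused "...self_attn.qkv_proj.weight" and loaded into its "q" shard;
        # "k_proj" and "v_proj" fill the "k" and "v" shards of the same tensor.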
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "lm_head.weight" in name:
                continue
            if name.startswith("decoder."):
                name = "model." + name

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
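

# Usage sketch (illustrative; in practice the vLLM engine constructs and
# drives this model, so these calls only show the entry points):
#
#   config = OPTConfig.from_pretrained("facebook/opt-125m")
#   model = OPTForCausalLM(config)
#   model.load_weights("facebook/opt-125m")
#
# forward() is then called with flattened token ids, positions, the paged KV
# caches, InputMetadata from the scheduler, and optional cache events, and
# sample() turns the resulting hidden states into next tokens.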