# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import OPTConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
                                                       ColumnParallelLinear,
                                                       RowParallelLinear)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]
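# One (key_cache, value_cache) tensor pair per decoder layer, read and
# updated in place by PagedAttention.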


class OPTLearnedPositionalEmbedding(nn.Embedding):

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor):
        return super().forward(positions + self.offset)


class OPTAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
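        # The attention heads are split evenly across tensor-parallel ranks;
        # each rank computes attention only for its own shard of heads.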
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

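        # Fused QKV projection: a single column-parallel matmul produces this
        # rank's query, key, and value slices, which forward() splits apart.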
        self.qkv_proj = ColumnParallelLinear(
            embed_dim,
            3 * embed_dim,
            bias=bias,
            gather_output=False,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            input_is_parallel=True,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
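        # Split the fused projection into query/key/value; each chunk covers
        # only this rank's attention heads.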
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
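        # PagedAttention stores the new keys/values in the paged KV cache and
        # attends over the cached context; input_metadata describes how the
        # flattened 1D token batch maps back to individual sequences.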
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata, cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class OPTDecoderLayer(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.activation_fn = get_act_fn(config.activation_function)

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=config.enable_bias,
            gather_output=False,
        )
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=config.enable_bias,
            input_is_parallel=True,
        )
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        # 125m, 1.3B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       input_metadata=input_metadata,
                                       cache_event=cache_event)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.3B, ..., 175B applies layer norm BEFORE the feed-forward block
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER the feed-forward block
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class OPTDecoder(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist.
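        # They are only created when word_embed_proj_dim differs from
        # hidden_size (e.g. OPT-350m), and they map embeddings into and out of
        # the decoder's hidden width.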
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size,
                                         config.word_embed_proj_dim,
                                         bias=False)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim,
                                        config.hidden_size,
                                        bias=False)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList(
            [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

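        # Run the decoder layers sequentially; each layer receives its own KV
        # cache and, when provided, the CUDA event associated with that
        # layer's cache operations.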
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
                                  cache_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        return hidden_states


class OPTModel(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.decoder = OPTDecoder(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        return self.decoder(input_ids, positions, kv_caches, input_metadata,
                            cache_events)


class OPTForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = OPTModel(config)
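        # OPT ties the LM head to the input token embedding, so the sampler
        # reuses the (vocab-parallel) embedding weight as the output
        # projection.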
        # TODO(zhuohan): create a new weight after implementing pipeline
        #                parallelism
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

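    # Parameter-name suffixes telling load_tensor_parallel_weights which
    # parameters are sharded along the output (column) or input (row)
    # dimension.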
    _column_parallel_weights = [
        "embed_tokens.weight", "fc1.weight", "fc1.bias"
    ]
    _row_parallel_weights = ["out_proj.weight", "fc2.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "lm_head.weight" in name:
                continue

            if name.startswith("decoder."):
                name = "model." + name

            is_attention_weight = False
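            # HF checkpoints store separate q_proj/k_proj/v_proj tensors; copy
            # this rank's shard of each into the corresponding third (stride)
            # of the fused qkv_proj parameter.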
            for stride_id, att_weight_name in enumerate(
                ["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                shard_size = param.shape[0] // 3
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)