opt.py 16.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
Woosuk Kwon's avatar
Woosuk Kwon committed
6
# Copyright 2023 The vLLM team.
7
8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
9
10
11
12
13
14
15
16
17
18
19
20
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
21
"""Inference-only OPT model compatible with HuggingFace weights."""
22
from collections.abc import Iterable
23
from itertools import islice
24
from typing import Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
25

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
26
27
28
29
import torch
from torch import nn
from transformers import OPTConfig

30
from vllm.attention import Attention
31
from vllm.compilation.decorators import support_torch_compile
32
from vllm.config import CacheConfig, VllmConfig
33
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.model_executor.layers.activation import get_act_fn
35
36
37
38
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
from vllm.model_executor.layers.vocab_parallel_embedding import (
42
    ParallelLMHead, VocabParallelEmbedding)
43
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
44
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
45

46
from .interfaces import SupportsPP
47
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
48
49
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
50

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
51
52
53
54

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embeddings using OPT's fixed index shift.

    The lookup table has ``num_embeddings + 2`` rows, and position ``p`` is
    read from row ``p + 2``. This mirrors the OPT convention of offsetting
    position ids by 2; other model families do not apply this shift.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Size the table with the extra leading rows before registering it,
        # then record the shift so forward() can apply it.
        shift = 2
        super().__init__(num_embeddings + shift, embedding_dim)
        self.offset = shift

    def forward(self, positions: torch.Tensor):
        """Return the embedding rows for ``positions`` after the OPT shift."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """OPT multi-head self-attention, sharded across tensor-parallel ranks.

    Q, K and V are produced by one fused ``QKVParallelLinear``. Each rank
    keeps ``num_heads // tp_size`` heads, where ``tp_size`` is the tensor
    model parallel world size. The attention result is projected back to
    ``embed_dim`` by a ``RowParallelLinear``.

    Args:
        embed_dim: Model hidden size.
        num_heads: Total (global) number of attention heads; must be
            divisible by the tensor-parallel world size.
        bias: Whether the QKV and output projections carry a bias.
        cache_config: Forwarded to the ``Attention`` layer.
        quant_config: Forwarded to the projections and ``Attention`` layer.
        prefix: Parameter-name prefix for the sub-modules.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim

        tp_size = get_tensor_model_parallel_world_size()
        # Heads are split evenly across tensor-parallel ranks.
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size
        # The per-head width comes from the *global* head count.
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states`` and project the result."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is split into three equal parts along the
        # last dimension: query, key, value.
        query, key, value = fused_qkv.chunk(3, dim=-1)
        context = self.attn(query, key, value)
        projected, _ = self.out_proj(context)
        return projected

Woosuk Kwon's avatar
Woosuk Kwon committed
118

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
119
120
class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention then a two-layer MLP.

    Each sub-block is wrapped in a residual connection. When
    ``config.do_layer_norm_before`` is true, the sub-block's LayerNorm is
    applied to its input (pre-LN). Otherwise it is applied to the residual
    sum (post-LN). The upstream comments note that 125m, 1.7B, ..., 175B use
    pre-LN and 350m uses post-LN.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        use_bias = config.enable_bias
        affine = config.layer_norm_elementwise_affine

        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim,
                                                 elementwise_affine=affine)
        # MLP: hidden -> ffn_dim (column-sharded) -> hidden (row-sharded).
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.activation_fn = get_act_fn(config.activation_function)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim,
                                             elementwise_affine=affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks, each with a residual."""
        pre_norm = self.do_layer_norm_before

        # Attention sub-block.
        shortcut = hidden_states
        if pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = shortcut + self.self_attn(hidden_states=hidden_states)
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # MLP sub-block (normalized by final_layer_norm, pre- or post-).
        shortcut = hidden_states
        if pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = shortcut + hidden_states
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
193
class OPTDecoder(nn.Module):
    """The OPT decoder stack with pipeline-parallel support.

    Token embeddings have width ``config.word_embed_proj_dim``. When that
    differs from ``config.hidden_size``, replicated linear layers project in
    (after embedding) and out (after the final norm). Only this rank's slice
    ``[start_layer, end_layer)`` of the decoder layers is executed.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated on every rank (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # The in/out projections exist only when the embedding width differs
        # from the hidden size; both are replicated.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.project_out")
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.project_in")
        else:
            self.project_out = None
            self.project_in = None

        # `config._remove_final_layer_norm` exists only for backward
        # compatibility with checkpoints fine-tuned before transformers
        # v4.20.1; see https://github.com/facebookresearch/metaseq/pull/164
        keep_final_norm = (config.do_layer_norm_before
                           and not config._remove_final_layer_norm)
        self.final_layer_norm = (nn.LayerNorm(
            config.hidden_size,
            elementwise_affine=config.layer_norm_elementwise_affine)
                                 if keep_final_norm else None)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: OPTDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (width ``word_embed_proj_dim``)."""
        return self.embed_tokens(input_ids)

    def _embed_first_rank(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Build the initial hidden states: token (+project_in) + position."""
        token_embeds = (self.get_input_embeddings(input_ids)
                        if inputs_embeds is None else inputs_embeds)
        position_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            token_embeds, _ = self.project_in(token_embeds)
        return token_embeds + position_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's layers.

        The first pipeline rank builds hidden states from the inputs; other
        ranks resume from ``intermediate_tensors["hidden_states"]``. Non-last
        ranks return ``IntermediateTensors``; the last rank applies the
        optional final norm and ``project_out`` and returns a tensor.
        """
        if get_pp_group().is_first_rank:
            hidden_states = self._embed_first_rank(input_ids, positions,
                                                   inputs_embeds)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for decoder_layer in islice(self.layers, self.start_layer,
                                    self.end_layer):
            hidden_states = decoder_layer(hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


284
@support_torch_compile
class OPTModel(nn.Module):
    """Wrapper around ``OPTDecoder`` that also implements weight loading.

    The decoder is stored as ``self.decoder``, so parameter names have the
    form ``decoder.<...>``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.decoder = OPTDecoder(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.decoder")
        # Only "hidden_states" is passed between pipeline ranks.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the decoder."""
        return self.decoder.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder; see ``OPTDecoder.forward`` for the contract."""
        return self.decoder(input_ids,
                            positions,
                            intermediate_tensors,
                            inputs_embeds=inputs_embeds)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint entries into this module.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are routed into the
        fused ``qkv_proj`` parameter with shard ids ``"q"``/``"k"``/``"v"``.
        All other tensors are loaded by name, using the parameter's
        ``weight_loader`` if present and ``default_weight_loader`` otherwise.
        Two kinds of tensor are skipped:
        - ``.bias`` tensors with no matching parameter (extra GPTQ biases);
        - tensors for parameters that ``is_pp_missing_parameter`` reports
          as absent on this pipeline rank.

        Returns:
            The (possibly renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Apply at most one stacked mapping per tensor. Scanning must stop
            # at the first match: "v_proj" is a substring of "qkv_proj", so
            # continuing after rewriting "q_proj"/"k_proj" would rewrite the
            # name a second time (e.g. into "...qkqkv_proj...").
            shard_id: Optional[str] = None
            for param_name, weight_name, stacked_shard_id in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Parameters owned by other pipeline ranks are not present here.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is not None:
                # Fused parameters always carry a shard-aware loader.
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
354

355
class OPTForCausalLM(nn.Module, SupportsPP):
    """OPT decoder plus a language-modeling head, with pipeline parallelism.

    When ``config.tie_word_embeddings`` is set, the LM head reuses the
    decoder's token-embedding module. In that case any ``lm_head.weight`` in
    the checkpoint is skipped during loading.
    """

    # NOTE(review): the "gate_up_proj" entry does not correspond to any
    # layer defined for OPT in this file (the MLP uses fc1/fc2); it looks
    # unused here — confirm before removing.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # HF checkpoints name the decoder "decoder.*"; here it sits under
    # "model.decoder.*".
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "decoder.": "model.decoder.",
    })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.quant_config = vllm_config.quant_config
        self.model = OPTModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "model"))

        if hf_config.tie_word_embeddings:
            # Share the input embedding table as the output projection.
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(hf_config.vocab_size,
                                          hf_config.word_embed_proj_dim,
                                          prefix=maybe_prefix(
                                              prefix, "lm_head"))
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Return hidden states (or intermediate tensors on non-last ranks)."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load HF-named weights, remapping ``decoder.`` to ``model.decoder.``.

        ``lm_head.weight`` is skipped when embeddings are tied.
        """
        skipped = (["lm_head.weight"]
                   if self.config.tie_word_embeddings else None)
        weights_loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return weights_loader.load_weights(weights,
                                           mapper=self.hf_to_vllm_mapper)