# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union

import torch
from torch import nn
from transformers import OPTConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned positional-embedding table whose lookups are shifted by 2.

    The table is allocated with ``num_embeddings + 2`` rows and every
    position passed to :meth:`forward` is shifted by the same offset before
    the lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
        """Allocate the table.

        Args:
            num_embeddings: Number of positions to support; the underlying
                table gets ``num_embeddings + 2`` rows.
            embedding_dim: Size of each embedding vector.
        """
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack.
        # The offset is assigned before the parent initialiser runs because
        # it is needed to size the table in that very call.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the embeddings for ``positions`` after adding the offset."""
        return super().forward(positions + self.offset)


class OPTAttention(nn.Module):
    """Self-attention block: fused QKV projection, attention, output projection.

    The attention heads are divided across the tensor-parallel group; each
    rank keeps ``num_heads // tensor_parallel_world_size`` heads.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and the attention layer.

        Args:
            embed_dim: Hidden size of the model.
            num_heads: Total number of attention heads (before sharding).
                Must be divisible by the tensor-parallel world size.
            bias: Whether the QKV and output projections carry a bias.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the projections and to ``Attention``.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.embed_dim = embed_dim

        tp_world_size = get_tensor_model_parallel_world_size()
        total_num_heads = num_heads
        assert num_heads % tp_world_size == 0
        # Heads handled by this rank.
        self.num_heads = total_num_heads // tp_world_size
        self.head_dim = embed_dim // total_num_heads
        # Attention scale: 1 / sqrt(head_dim).
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states`` and project the result."""
        # The parallel linear layers return a tuple; only the first element
        # (the projected tensor) is used here.
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into query / key / value along the last
        # dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output

class OPTDecoderLayer(nn.Module):
    """One OPT transformer layer: self-attention followed by a feed-forward MLP.

    Each of the two sub-blocks is wrapped in a residual connection. The
    ``config.do_layer_norm_before`` flag selects whether the layer norm is
    applied to the sub-block input (before the residual add) or to the
    sub-block output (after the residual add).
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention block, the MLP and both layer norms.

        Args:
            config: Hugging Face OPT configuration. Reads ``hidden_size``,
                ``num_attention_heads``, ``enable_bias``, ``ffn_dim``,
                ``activation_function``, ``do_layer_norm_before`` and
                ``layer_norm_elementwise_affine``.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention block and the MLP layers.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.activation_fn = get_act_fn(config.activation_function)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and the MLP, each with a residual connection."""
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # Same pre/post switch as above, applied to the feed-forward block:
        # layer norm BEFORE the MLP when do_layer_norm_before is set ...
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        # The parallel linear layers return a tuple; only the projected
        # tensor is used.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # ... otherwise layer norm AFTER the MLP's residual add.
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class OPTDecoder(nn.Module):
    """OPT decoder stack: embeddings, decoder layers and optional final norm.

    Supports pipeline parallelism: only the first pipeline rank computes the
    input embeddings, and only the last rank applies the final layer norm and
    the output projection. Intermediate ranks exchange hidden states through
    ``IntermediateTensors``.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build embeddings, optional in/out projections and the layer stack.

        Args:
            config: Hugging Face OPT configuration.
            cache_config: Forwarded to every decoder layer.
            quant_config: Forwarded to the projections and decoder layers.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.config = config
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist. They are only
        # needed when the word-embedding width differs from the hidden size.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.project_out")
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.project_in")
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        # `make_layers` also reports the [start_layer, end_layer) range that
        # `forward` iterates over on this rank.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: OPTDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Position ids fed to the positional embedding.
            intermediate_tensors: Hidden states received from the previous
                pipeline rank; must not be ``None`` on non-first ranks.
            inputs_embeds: Optional precomputed token embeddings that replace
                the ``input_ids`` lookup.

        Returns:
            The final hidden states on the last pipeline rank, otherwise an
            ``IntermediateTensors`` carrying ``"hidden_states"`` for the next
            rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            pos_embeds = self.embed_positions(positions)
            if self.project_in is not None:
                inputs_embeds, _ = self.project_in(inputs_embeds)
            hidden_states = inputs_embeds + pos_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # Only the layers in [start_layer, end_layer) are run here.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


@support_torch_compile
class OPTModel(nn.Module):
    """Thin wrapper around :class:`OPTDecoder` that also loads its weights."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the decoder from ``vllm_config``.

        Args:
            vllm_config: Supplies the HF model config
                (``model_config.hf_config``), ``cache_config`` and
                ``quant_config``.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.decoder = OPTDecoder(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.decoder")
        # Factory for the placeholder "hidden_states" tensors used between
        # pipeline-parallel ranks.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.decoder.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the decoder; see :meth:`OPTDecoder.forward`."""
        return self.decoder(input_ids,
                            positions,
                            intermediate_tensors,
                            inputs_embeds=inputs_embeds)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint entries named ``q_proj`` / ``k_proj`` / ``v_proj`` are
        routed into the fused ``qkv_proj`` parameter with the matching shard
        id; everything else is loaded by name.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE: this `continue` (and the one below) only advances the
                # inner loop; `name` stays renamed, so when no later mapping
                # entry matches, the for-else branch re-applies the same
                # checks and skips the weight there.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping was applied via `break`: load by name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                # Fall back to the default loader for parameters that do not
                # define their own `weight_loader`.
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

class OPTForCausalLM(nn.Module, SupportsPP):
    """OPT with a language-model head on top of :class:`OPTModel`."""

    # Fused parameter name -> checkpoint sub-module names it is built from.
    # NOTE(review): no `gate_up_proj` module is created in this file; that
    # entry is kept unchanged here - confirm whether anything relies on it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # Checkpoint names starting with "decoder." live under "model.decoder."
    # in this module tree.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "decoder.": "model.decoder.",
    })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the base model, the LM head and the logits processor.

        Args:
            vllm_config: Supplies the HF model config
                (``model_config.hf_config``) and ``quant_config``.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "model"))
        if self.config.tie_word_embeddings:
            # Tied weights: reuse the input embedding module as the LM head.
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.word_embed_proj_dim,
                                          prefix=maybe_prefix(
                                              prefix, "lm_head"))
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the base model and return its output unchanged.

        Logits are not computed here; see :meth:`compute_logits`.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn ``hidden_states`` into logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors through ``AutoWeightsLoader``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs using checkpoint
                names; ``hf_to_vllm_mapper`` is applied to them.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated as
            the set of loaded parameter names).
        """
        loader = AutoWeightsLoader(
            self,
            # With tied embeddings the head shares the embedding module, so a
            # separate "lm_head.weight" checkpoint entry is skipped.
            skip_prefixes=(["lm_head.weight"]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)