# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
19
from typing import Iterable, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
20

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
21
22
23
24
import torch
from torch import nn
from transformers import OPTConfig

25
from vllm.attention import Attention, AttentionMetadata
26
from vllm.compilation.decorators import support_torch_compile
27
from vllm.config import CacheConfig
28
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
29
from vllm.model_executor.layers.activation import get_act_fn
30
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
from vllm.model_executor.layers.quantization import QuantizationConfig
36
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
38
    ParallelLMHead, VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
42

43
44
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
45
46
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
47

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
48
49
50
51

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned positional embedding table whose lookups are shifted by 2.

    The table is allocated with two extra rows and every position id is
    offset by the same amount before the lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified, the embedding
        # ids are offset by 2 and num_embeddings is adjusted accordingly.
        # Other models don't have this hack.
        offset = 2
        super().__init__(num_embeddings + offset, embedding_dim)
        self.offset = offset

    def forward(self, positions: torch.Tensor):
        """Return the embeddings for ``positions`` after applying the offset."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """Self-attention block of an OPT decoder layer.

    Q, K and V come from one fused ``QKVParallelLinear``; the attention
    result is projected back to ``embed_dim`` by a ``RowParallelLinear``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and the attention backend.

        Args:
            embed_dim: Width of the hidden states entering the block.
            num_heads: Total number of attention heads; must be divisible
                by the tensor-parallel world size.
            bias: Whether the QKV and output projections carry a bias.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the projections and ``Attention``.
            prefix: Dotted name prefix used to name the sub-modules.
        """
        super().__init__()
        self.embed_dim = embed_dim

        # This rank keeps num_heads / tp_size heads; head_dim is derived
        # from the *total* head count.
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run fused QKV projection, attention, and the output projection."""
        # The projection layers return a tuple; only the first element (the
        # projected tensor) is used here.
        fused_qkv, _ = self.qkv_proj(hidden_states)
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.out_proj(context)
        return projected

Woosuk Kwon's avatar
Woosuk Kwon committed
117

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
118
119
class OPTDecoderLayer(nn.Module):
    """One OPT transformer layer: self-attention followed by a two-layer MLP.

    ``config.do_layer_norm_before`` selects whether each LayerNorm runs
    before its sub-block or after the residual addition.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the attention, MLP and LayerNorm sub-modules from ``config``."""
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size

        use_bias = config.enable_bias
        affine_norm = config.layer_norm_elementwise_affine

        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=affine_norm)

        # Feed-forward network: embed_dim -> ffn_dim -> embed_dim.
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config, config.ffn_dim)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=affine_norm)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-blocks with residuals."""
        pre_norm = self.do_layer_norm_before

        # ---- Self-attention sub-block ----
        shortcut = hidden_states
        # 125m, 1.7B, ..., 175B apply the layer norm BEFORE attention.
        if pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        attn_out = self.self_attn(hidden_states=hidden_states,
                                  kv_cache=kv_cache,
                                  attn_metadata=attn_metadata)
        hidden_states = shortcut + attn_out
        # 350m applies the layer norm AFTER attention.
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # ---- Feed-forward sub-block ----
        shortcut = hidden_states
        # Pre-norm checkpoints normalise before the feed-forward network.
        if pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        ffn_out, _ = self.fc1(hidden_states)
        ffn_out = self.activation_fn(ffn_out)
        ffn_out, _ = self.fc2(ffn_out)
        hidden_states = shortcut + ffn_out
        # Post-norm checkpoints normalise after the residual addition.
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
197
class OPTDecoder(nn.Module):
    """Token/position embeddings, the layer stack and the optional final norm.

    Supports pipeline parallelism: only the first rank embeds the inputs and
    only the last rank applies the final norm / output projection.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create embeddings, optional projections, final norm and layers."""
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        hidden_size = config.hidden_size
        embed_proj_dim = config.word_embed_proj_dim

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, hidden_size)

        # Project out & in exist only when the word-embedding width differs
        # from the hidden width; when present they are replicated.
        if embed_proj_dim != hidden_size:
            self.project_out = ReplicatedLinear(
                hidden_size,
                embed_proj_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.project_out")
            self.project_in = ReplicatedLinear(
                embed_proj_dim,
                hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.project_in")
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is
        # to keep backward compatibility with checkpoints that have been
        # fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        needs_final_norm = (config.do_layer_norm_before
                            and not config._remove_final_layer_norm)
        if needs_final_norm:
            self.final_layer_norm = nn.LayerNorm(
                hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        # The parameter must stay named `prefix`.
        # NOTE(review): make_layers presumably calls the factory with a
        # `prefix=` keyword -- confirm against .utils.
        def build_layer(prefix: str) -> OPTDecoderLayer:
            return OPTDecoderLayer(config,
                                   cache_config,
                                   quant_config,
                                   prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder.

        Returns the final hidden states on the last pipeline rank, otherwise
        an ``IntermediateTensors`` holding ``"hidden_states"`` for the next
        rank. ``inputs_embeds``, when given, replaces the token-embedding
        lookup on the first rank.
        """
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            pos_embeds = self.embed_positions(positions)
            if self.project_in is not None:
                inputs_embeds, _ = self.project_in(inputs_embeds)
            hidden_states = inputs_embeds + pos_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # kv_caches is indexed relative to this rank's first layer.
        local_layer_ids = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(local_layer_ids):
            decoder_layer = self.layers[layer_idx]
            hidden_states = decoder_layer(hidden_states,
                                          kv_caches[cache_idx],
                                          attn_metadata)

        if not pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


294
@support_torch_compile
class OPTModel(nn.Module):
    """Thin wrapper that owns the ``OPTDecoder`` under the name ``decoder``."""

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the decoder and the empty-intermediate-tensors factory."""
        super().__init__()
        decoder_prefix = f"{prefix}.decoder"
        self.decoder = OPTDecoder(config,
                                  cache_config,
                                  quant_config,
                                  prefix=decoder_prefix)
        # Factory built for the single "hidden_states" key at hidden_size.
        tensor_keys = ["hidden_states"]
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(tensor_keys,
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the token-embedding lookup to the decoder."""
        return self.decoder.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward every argument to the decoder and return its result."""
        decoder_output = self.decoder(input_ids,
                                      positions,
                                      kv_caches,
                                      attn_metadata,
                                      intermediate_tensors,
                                      inputs_embeds=inputs_embeds)
        return decoder_output
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
331
332


333
class OPTForCausalLM(nn.Module, SupportsPP):
    """OPT with a language-modeling head, logits processor and sampler.

    The LM head is the token-embedding module itself when
    ``config.tie_word_embeddings`` is set, otherwise a separate
    ``ParallelLMHead``.
    """

    # BitsAndBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
    }
    default_bitsandbytes_target_modules = [
        ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
    ]

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the backbone, LM head, logits processor and sampler."""
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(config,
                              cache_config,
                              quant_config,
                              prefix=maybe_prefix(prefix, "model"))
        if self.config.tie_word_embeddings:
            # Share the input embedding module as the output head.
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.word_embed_proj_dim)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings for ``input_ids`` via the wrapped model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged.

        ``inputs_embeds`` is optional and is forwarded to ``OPTModel`` so
        callers can supply precomputed embeddings; omitting it keeps the
        previous behavior.
        """
        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits using the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` checkpoint entries are
        routed into the fused ``qkv_proj`` parameter with the matching shard
        id. Entries are skipped when they are the LM head of a tied model, a
        ``.bias`` with no matching parameter, or reported missing on this
        rank by ``is_pp_missing_parameter``.

        Raises:
            KeyError: If a remaining checkpoint name has no parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                continue
            if name.startswith("decoder."):
                name = "model." + name

            # Resolve the fused-parameter name exactly once. Stopping at the
            # first match matters: "v_proj" is a substring of "qkv_proj", so
            # testing later mappings against the rewritten name would
            # rewrite it a second time.
            shard_id = None
            for param_name, weight_name, candidate_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = candidate_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is not None:
                # Fused parameters must provide their own shard-aware loader.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)