# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
20
from typing import Iterable, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
21

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
22
23
24
25
import torch
from torch import nn
from transformers import OPTConfig

26
from vllm.attention import Attention, AttentionMetadata
27
from vllm.compilation.decorators import support_torch_compile
28
from vllm.config import CacheConfig
29
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
30
from vllm.model_executor.layers.activation import get_act_fn
31
32
33
34
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
35
from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
from vllm.model_executor.layers.quantization import QuantizationConfig
37
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
39
    ParallelLMHead, VocabParallelEmbedding)
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
41
from vllm.model_executor.sampling_metadata import SamplingMetadata
42
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
43

44
45
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
46
47
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
48

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
49
50
51
52

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embeddings for OPT.

    The lookup table is allocated with ``offset`` (= 2) extra rows, and
    every position id is shifted by ``offset`` before indexing, so
    position ``p`` reads row ``p + 2`` of ``weight``.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT checkpoints shift position ids by 2. The upstream HF comment
        # ties this to padding_idx, but here the shift is unconditional.
        # Other models do not have this hack. `offset` must be set before
        # the table size is computed.
        self.offset = 2
        table_rows = num_embeddings + self.offset
        super().__init__(table_rows, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Look up embeddings for ``positions`` after applying the offset."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """Multi-head self-attention for OPT, sharded over tensor-parallel ranks.

    Q, K and V come from a single fused ``QKVParallelLinear`` projection.
    Each rank keeps ``num_heads // tp_world_size`` heads. The attention
    computation itself (including use of ``kv_cache``) is delegated to
    ``vllm.attention.Attention``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim

        tp_size = get_tensor_model_parallel_world_size()
        # Heads must split evenly across tensor-parallel ranks.
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size
        # head_dim is derived from the *total* head count, not the local one.
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim**-0.5

        # Fused Q/K/V projection; no separate KV-head count is passed.
        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, attend, and project back to ``embed_dim``.

        Only the first value returned by each projection layer is used.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # chunk(3) assumes q, k and v have equal widths, which holds because
        # qkv_proj is built without a separate KV-head count.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.out_proj(context)
        return projected

Woosuk Kwon's avatar
Woosuk Kwon committed
118

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
119
120
class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention followed by an MLP.

    Both sub-blocks are residual. ``config.do_layer_norm_before`` selects
    where each LayerNorm sits. When true (pre-norm), the norm is applied to
    the sub-block input. When false (post-norm), it is applied after the
    residual add.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        use_bias = config.enable_bias
        norm_affine = config.layer_norm_elementwise_affine

        # Attention sub-block.
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=norm_affine)

        # Feed-forward sub-block: fc1 (column-parallel) -> activation ->
        # fc2 (row-parallel), going embed_dim -> ffn_dim -> embed_dim.
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config, config.ffn_dim)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=norm_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention then the MLP, each with a residual connection."""
        pre_norm = self.do_layer_norm_before

        # ---- Self-attention sub-block ----
        residual = hidden_states
        # Pre-norm (125m, 1.7B, ..., 175B): normalize before attention.
        if pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        attn_out = self.self_attn(hidden_states=hidden_states,
                                  kv_cache=kv_cache,
                                  attn_metadata=attn_metadata)
        hidden_states = residual + attn_out
        # Post-norm (350m): normalize after the attention residual add.
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # ---- Feed-forward sub-block ----
        residual = hidden_states
        # Pre-norm: normalize before the feed-forward network.
        if pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        mlp_out, _ = self.fc1(hidden_states)
        mlp_out = self.activation_fn(mlp_out)
        mlp_out, _ = self.fc2(mlp_out)
        hidden_states = residual + mlp_out
        # Post-norm: normalize after the feed-forward residual add.
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
198
class OPTDecoder(nn.Module):
    """OPT decoder stack: embeddings, decoder layers and output projections.

    With pipeline parallelism, this rank runs only layers
    ``[start_layer, end_layer)``. The first rank builds the input embeddings.
    Any rank other than the last returns ``IntermediateTensors`` instead of
    final hidden states.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        # Token embeddings live in word_embed_proj_dim, which may differ
        # from hidden_size.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # When the embedding width differs from hidden_size, replicated
        # (unsharded) linear layers project into and out of hidden_size.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.project_out")
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.project_in")
        else:
            self.project_out = None
            self.project_in = None

        # `config._remove_final_layer_norm` exists only for backward
        # compatibility with checkpoints fine-tuned before transformers
        # v4.20.1; see https://github.com/facebookresearch/metaseq/pull/164
        if not config.do_layer_norm_before or config._remove_final_layer_norm:
            self.final_layer_norm = None
        else:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)

        # make_layers passes the per-layer prefix by keyword, so the lambda
        # parameter must be called `prefix`.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: OPTDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (width ``word_embed_proj_dim``)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder.

        ``kv_caches`` is indexed relative to ``start_layer``. On the first
        rank, ``inputs_embeds`` (if given) replaces the token lookup.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            # Later pipeline stages resume from the previous stage's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            pos_embeds = self.embed_positions(positions)
            if self.project_in is not None:
                inputs_embeds, _ = self.project_in(inputs_embeds)
            hidden_states = inputs_embeds + pos_embeds

        for layer_idx in range(self.start_layer, self.end_layer):
            local_idx = layer_idx - self.start_layer
            hidden_states = self.layers[layer_idx](hidden_states,
                                                   kv_caches[local_idx],
                                                   attn_metadata)

        if not pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


295
@support_torch_compile
class OPTModel(nn.Module):
    """OPT backbone; all computation is delegated to ``self.decoder``.

    NOTE(review): ``support_torch_compile`` presumably inspects
    ``forward``'s signature and annotations, so both are kept exactly as
    they were.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.decoder = OPTDecoder(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.decoder",
        )
        # The decoder hands "hidden_states" (feature size hidden_size)
        # between pipeline stages; this factory builds empty placeholders
        # for that payload.
        hidden_size = config.hidden_size
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the token-embedding lookup to the decoder."""
        return self.decoder.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward every argument unchanged to ``OPTDecoder.forward``."""
        return self.decoder(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
332
333


334
class OPTForCausalLM(nn.Module, SupportsPP):
    """OPT causal language model: ``OPTModel`` + LM head + logits/sampling.

    When ``config.tie_word_embeddings`` is set, the LM head is the decoder's
    ``embed_tokens`` module itself. Otherwise a separate ``ParallelLMHead``
    of width ``word_embed_proj_dim`` is created.
    """

    # bitsandbytes-specific attributes.
    bitsandbytes_stacked_params_mapping = {
        # shard_name: (fused weight_name, shard index)
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
    }
    default_bitsandbytes_target_modules = [
        ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
    ]

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(config,
                              cache_config,
                              quant_config,
                              prefix=maybe_prefix(prefix, "model"))
        if self.config.tie_word_embeddings:
            # Reuse the input embedding module as the output projection.
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.word_embed_proj_dim)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the decoder's token embedding table."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone.

        ``inputs_embeds`` is optional and defaults to ``None``, so existing
        callers are unaffected. It is forwarded to ``OPTModel``, which
        already accepts it. When given, it replaces the token-embedding
        lookup on the first pipeline rank.

        Returns final hidden states on the last pipeline rank, otherwise
        ``IntermediateTensors`` for the next stage.
        """
        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HF-format OPT weights into this (possibly sharded) model.

        Separate ``q_proj``/``k_proj``/``v_proj`` checkpoint tensors are
        routed into the fused ``qkv_proj`` parameter through its
        ``weight_loader`` with shard id "q"/"k"/"v". Checkpoint names rooted
        at ``decoder.`` are mapped under ``model.``. Parameters that
        live on another pipeline stage are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        # remove_duplicate=False keeps aliased names (e.g. a tied lm_head
        # that is the same module as embed_tokens) in the lookup table.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            # A tied lm_head shares embed_tokens' weight; nothing to load.
            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                continue
            if name.startswith("decoder."):
                name = "model." + name

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not (successfully) handled as a stacked Q/K/V shard.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)