# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import OPTConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned positional embedding whose lookups are shifted by a fixed offset.

    The table is allocated with ``num_embeddings + 2`` rows and every lookup
    adds 2 to the requested position, so position ``p`` reads row ``p + 2``.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        """Allocate the embedding table.

        Args:
            num_embeddings: Number of addressable positions (before the
                offset is added).
            embedding_dim: Size of each embedding vector.
        """
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack
        # The offset is assigned first because the table size passed to the
        # parent constructor depends on it.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Return the embeddings for ``positions`` shifted by ``self.offset``."""
        return super().forward(positions + self.offset)


class OPTAttention(nn.Module):
    """Self-attention for OPT built from tensor-parallel linear layers.

    A fused QKV projection feeds the ``Attention`` layer, whose output is
    passed through an output projection.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the projections and the attention layer.

        Args:
            embed_dim: Hidden size of the attention input and output.
            num_heads: Total number of attention heads across all
                tensor-parallel ranks; must be divisible by the
                tensor-parallel world size.
            bias: Whether the QKV and output projections carry a bias.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0, (
            f"num_heads ({num_heads}) must be divisible by the tensor "
            f"parallel world size ({tensor_model_parallel_world_size})")
        # Heads handled by this rank.
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        # Attention scale: 1 / sqrt(head_dim).
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Input to the fused QKV projection.
            kv_cache: KV cache tensor handed to the attention layer.
            attn_metadata: Attention metadata handed to the attention layer.

        Returns:
            The output of the output projection.
        """
        # The linear layers return a tuple; only the first element is used.
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into three equal chunks on the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output


class OPTDecoderLayer(nn.Module):
    """A single OPT decoder layer: self-attention followed by an MLP.

    Each sub-block is wrapped in a residual connection. Depending on
    ``config.do_layer_norm_before`` the layer norm is applied either to the
    sub-block input (before) or to the residual sum (after).
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the attention, layer norms and feed-forward projections.

        Args:
            config: OPT configuration; reads ``hidden_size``,
                ``num_attention_heads``, ``enable_bias``,
                ``do_layer_norm_before``, ``layer_norm_elementwise_affine``,
                ``ffn_dim`` and ``activation_function``.
            cache_config: Forwarded to ``OPTAttention``.
            quant_config: Forwarded to the attention, the linear layers and
                the activation factory.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
        )
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config, config.ffn_dim)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
        )
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layer on ``hidden_states``.

        Args:
            hidden_states: Layer input.
            kv_cache: KV cache tensor passed to the attention block.
            attn_metadata: Attention metadata passed to the attention block.

        Returns:
            The hidden states after the attention and feed-forward blocks.
        """
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE the feed-forward
        # block
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        # The linear layers return a tuple; only the first element is used.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER the feed-forward block
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class OPTDecoder(nn.Module):
    """OPT decoder stack: embeddings, decoder layers and final projection.

    On the first pipeline-parallel rank the token and positional embeddings
    are computed; on other ranks the hidden states are read from the
    incoming ``IntermediateTensors``. Only the layers in
    ``[start_layer, end_layer)`` are executed on this rank.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build embeddings, optional projections, final norm and layers.

        Args:
            config: OPT configuration.
            cache_config: Forwarded to every decoder layer.
            quant_config: Forwarded to the projections and decoder layers.
            prefix: Name prefix handed to ``make_layers`` as
                ``f"{prefix}.layers"``.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist. They are only
        # needed when the word-embedding width differs from the hidden size.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                quant_config=quant_config)
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               quant_config=quant_config)
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        # The factory receives a per-layer prefix but OPTDecoderLayer does
        # not take one, so the argument is intentionally unused.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: OPTDecoderLayer(config, cache_config, quant_config),
            prefix=f"{prefix}.layers")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's share of the decoder.

        Args:
            input_ids: Token ids; embedded on the first rank when
                ``inputs_embeds`` is not given.
            positions: Position ids for the positional embedding.
            kv_caches: One KV cache per layer executed on this rank, indexed
                relative to ``start_layer``.
            attn_metadata: Attention metadata passed to every layer.
            intermediate_tensors: Must be provided on non-first ranks; its
                ``"hidden_states"`` entry is used as the layer input.
            inputs_embeds: Optional precomputed token embeddings that replace
                the ``input_ids`` lookup on the first rank.

        Returns:
            ``IntermediateTensors`` holding ``"hidden_states"`` on every rank
            but the last; on the last rank, the hidden states after the
            optional final layer norm and output projection.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            pos_embeds = self.embed_positions(positions)
            if self.project_in is not None:
                inputs_embeds, _ = self.project_in(inputs_embeds)
            hidden_states = inputs_embeds + pos_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # kv_caches only covers the layers owned by this rank.
            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


@support_torch_compile
class OPTModel(nn.Module):
    """Thin wrapper that owns the ``OPTDecoder`` under the ``decoder`` name."""

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the decoder and the intermediate-tensor factory.

        Args:
            config: OPT configuration.
            cache_config: Forwarded to the decoder.
            quant_config: Forwarded to the decoder.
        """
        super().__init__()
        self.decoder = OPTDecoder(config, cache_config, quant_config)
        # Factory built for the "hidden_states" key and the model's hidden
        # size.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the decoder's token embeddings for ``input_ids``."""
        return self.decoder.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the decoder; arguments and result are passed through."""
        return self.decoder(input_ids,
                            positions,
                            kv_caches,
                            attn_metadata,
                            intermediate_tensors,
                            inputs_embeds=inputs_embeds)


class OPTForCausalLM(nn.Module, SupportsPP):
    """OPT with a language-model head, logits processor and sampler."""

    # BitsAndBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
    }
    default_bitsandbytes_target_modules = [
        ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".out_proj.", ".fc2."]

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the model, LM head, logits processor and sampler.

        Args:
            config: OPT configuration.
            cache_config: Forwarded to ``OPTModel``.
            quant_config: Stored on the instance and forwarded to
                ``OPTModel``.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(config, cache_config, quant_config)
        if self.config.tie_word_embeddings:
            # Tied weights: reuse the input embedding module as the LM head.
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.word_embed_proj_dim)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the underlying model and return its output unchanged."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` using the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are routed into the
        fused ``qkv_proj`` parameter with the matching shard id. Names that
        start with ``decoder.`` are prefixed with ``model.``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Raises:
            KeyError: If a weight name does not resolve to a known parameter
                and is not skipped by one of the checks below.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            # With tied embeddings the LM head shares the embedding weight,
            # so a separate lm_head tensor in the checkpoint is ignored.
            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                continue
            if name.startswith("decoder."):
                name = "model." + name

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached when no stacked mapping consumed the weight.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)