# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
19
from typing import Iterable, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
20

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
21
22
23
24
import torch
from torch import nn
from transformers import OPTConfig

25
from vllm.attention import Attention, AttentionMetadata
26
from vllm.compilation.decorators import support_torch_compile
27
from vllm.config import CacheConfig, VllmConfig
28
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
29
from vllm.model_executor.layers.activation import get_act_fn
30
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
36
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
38
    ParallelLMHead, VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
42

43
44
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
45
46
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
47

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
48
49
50
51

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embedding using OPT's fixed index shift.

    OPT offsets every position id by 2 (and enlarges the table by the same
    amount) when a padding index is in use; other models do not do this.
    The shift is stored in ``self.offset`` and applied on every lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # The shift has to be known before the table is allocated, because
        # the table carries `shift` extra leading rows.
        shift = 2
        self.offset = shift
        super().__init__(num_embeddings + shift, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Return the embedding rows for ``positions`` shifted by ``offset``."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """OPT multi-head self-attention, split across tensor-parallel ranks.

    The ``num_heads`` heads are divided evenly among the tensor-parallel
    ranks.  Query/key/value come from one fused ``QKVParallelLinear``, the
    attention itself is delegated to vLLM's ``Attention`` layer, and the
    result is projected back to ``embed_dim`` by a row-parallel
    ``out_proj``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim

        tp_size = get_tensor_model_parallel_world_size()
        # Every tensor-parallel rank must own the same whole number of heads.
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size
        # NOTE(review): floor division; assumes embed_dim is a multiple of
        # num_heads (true for released OPT configs) — not checked here.
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim**-0.5

        qkv_prefix = f"{prefix}.qkv_proj"
        out_prefix = f"{prefix}.out_proj"
        attn_prefix = f"{prefix}.attn"

        # One fused projection; the full (unsharded) head count is passed in.
        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=qkv_prefix,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=out_prefix,
        )
        # The attention kernel only sees this rank's share of the heads.
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=attn_prefix,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states`` and return the output."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # Q, K and V occupy three equal slices of the last dimension.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.out_proj(context)
        return projected

Woosuk Kwon's avatar
Woosuk Kwon committed
117

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
118
119
class OPTDecoderLayer(nn.Module):
    """A single OPT transformer block: self-attention followed by an MLP.

    Each sub-block has a residual connection.  Whether its LayerNorm runs
    before the sub-block (pre-norm) or after the residual add (post-norm)
    is controlled by ``config.do_layer_norm_before``.  The original code
    notes that most OPT sizes use pre-norm and 350m uses post-norm.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        hidden = self.embed_dim
        use_bias = config.enable_bias

        self.self_attn = OPTAttention(
            embed_dim=hidden,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        self.self_attn_layer_norm = nn.LayerNorm(
            hidden, elementwise_affine=config.layer_norm_elementwise_affine)
        # MLP: column-parallel up-projection, activation, row-parallel
        # down-projection back to the hidden size.
        self.fc1 = ColumnParallelLinear(
            hidden,
            config.ffn_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.activation_fn = get_act_fn(config.activation_function)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            hidden,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.final_layer_norm = nn.LayerNorm(
            hidden, elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-blocks and return the new states."""
        pre_norm = self.do_layer_norm_before

        # ---- self-attention sub-block ----
        skip = hidden_states
        attn_in = (self.self_attn_layer_norm(hidden_states)
                   if pre_norm else hidden_states)
        attn_out = self.self_attn(hidden_states=attn_in,
                                  kv_cache=kv_cache,
                                  attn_metadata=attn_metadata)
        x = skip + attn_out
        if not pre_norm:
            # Post-norm variant: normalise after the residual add.
            x = self.self_attn_layer_norm(x)

        # ---- feed-forward sub-block ----
        skip = x
        if pre_norm:
            # Pre-norm variant: normalise before the feed-forward block.
            x = self.final_layer_norm(x)
        x, _ = self.fc1(x)
        x = self.activation_fn(x)
        x, _ = self.fc2(x)
        x = skip + x
        if not pre_norm:
            x = self.final_layer_norm(x)
        return x


Zhuohan Li's avatar
Zhuohan Li committed
196
class OPTDecoder(nn.Module):
    """OPT decoder stack together with its embeddings and output norm.

    On the first pipeline-parallel rank it embeds tokens and positions,
    projecting the word embeddings up to ``hidden_size`` when the two
    widths differ.  Every rank runs its own ``[start_layer, end_layer)``
    slice of decoder layers.  Non-last ranks return the hidden states as
    ``IntermediateTensors``.  The last rank applies the optional final
    LayerNorm and the optional projection back to ``word_embed_proj_dim``.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        hidden_size = config.hidden_size
        word_dim = config.word_embed_proj_dim

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            word_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, hidden_size)

        # When the word-embedding width differs from the model width, a pair
        # of replicated (unsharded) projections bridges the two sizes.
        if word_dim != hidden_size:
            self.project_out = ReplicatedLinear(hidden_size,
                                                word_dim,
                                                bias=False,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.project_out")
            self.project_in = ReplicatedLinear(word_dim,
                                               hidden_size,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.project_in")
        else:
            self.project_out = None
            self.project_in = None

        # `config._remove_final_layer_norm` exists only for backward
        # compatibility with checkpoints fine-tuned before transformers
        # v4.20.1; see https://github.com/facebookresearch/metaseq/pull/164
        keep_final_norm = (config.do_layer_norm_before
                           and not config._remove_final_layer_norm)
        if keep_final_norm:
            self.final_layer_norm = nn.LayerNorm(
                hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        # make_layers invokes the factory with a `prefix=` keyword argument,
        # so the parameter name must stay `prefix`.
        def build_layer(prefix: str) -> OPTDecoderLayer:
            return OPTDecoderLayer(config,
                                   cache_config,
                                   quant_config,
                                   prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the vocab-parallel token embeddings for ``input_ids``."""
        token_embeddings = self.embed_tokens(input_ids)
        return token_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's share of the decoder.

        Returns final hidden states on the last pipeline rank, otherwise an
        ``IntermediateTensors`` carrying ``"hidden_states"`` to the next rank.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            # Later pipeline stages resume from the previous stage's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            token_embeds = (self.get_input_embeddings(input_ids)
                            if inputs_embeds is None else inputs_embeds)
            if self.project_in is not None:
                token_embeds, _ = self.project_in(token_embeds)
            hidden_states = token_embeds + self.embed_positions(positions)

        # kv_caches holds only this rank's layers, hence the re-basing.
        for layer_idx in range(self.start_layer, self.end_layer):
            decoder_layer = self.layers[layer_idx]
            hidden_states = decoder_layer(
                hidden_states,
                kv_caches[layer_idx - self.start_layer],
                attn_metadata,
            )

        if not pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


293
@support_torch_compile
class OPTModel(nn.Module):
    """Wrapper that holds the :class:`OPTDecoder` under the name ``decoder``.

    Its submodule parameters therefore live under ``decoder.*``.  This class
    is the unit wrapped by ``@support_torch_compile``, so its ``forward``
    signature is kept exactly as-is.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        decoder_prefix = f"{prefix}.decoder"
        self.decoder = OPTDecoder(config=config,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=decoder_prefix)
        # Factory for the placeholder tensors exchanged between
        # pipeline-parallel ranks (a single "hidden_states" entry).
        empty_factory = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size)
        self.make_empty_intermediate_tensors = empty_factory

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the decoder."""
        decoder = self.decoder
        return decoder.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward every argument unchanged to :meth:`OPTDecoder.forward`."""
        return self.decoder(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
330
331


332
class OPTForCausalLM(nn.Module, SupportsPP):
    """OPT causal language model for vLLM inference.

    Combines :class:`OPTModel` with a vocabulary head, the logits processor
    and the sampler.  It also loads HuggingFace OPT checkpoints, writing
    their separate ``q_proj``/``k_proj``/``v_proj`` weights into the fused
    ``qkv_proj`` parameter.
    """

    # BitsAndBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
    }
    default_bitsandbytes_target_modules = [
        ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
    ]

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        # Initialise nn.Module exactly once, before any attribute is set.
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(config,
                              cache_config,
                              quant_config,
                              prefix=maybe_prefix(prefix, "model"))
        # With tied embeddings the output head is the input token-embedding
        # module itself, so no separate matrix is allocated or loaded.
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.word_embed_proj_dim)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the OPT model.

        Returns hidden states on the last pipeline-parallel rank and
        ``IntermediateTensors`` on earlier ranks; logits are computed
        separately by :meth:`compute_logits`.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states onto the vocabulary via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HuggingFace OPT checkpoint tensors into this model.

        Args:
            weights: ``(name, tensor)`` pairs as stored in the checkpoint.

        Names starting with ``decoder.`` are re-rooted under ``model.``.
        Separate q/k/v projections are written into their shard of the
        fused ``qkv_proj`` parameter.  Some tensors are skipped:
        ``lm_head.weight`` when embeddings are tied, extra ``.bias``
        tensors with no matching parameter (GPTQ), and parameters owned by
        another pipeline-parallel stage.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                continue
            if name.startswith("decoder."):
                name = "model." + name

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Once a shard has matched, stop scanning the mapping whether
                # or not the tensor is loaded.  "qkv_proj" itself contains
                # "v_proj", so continuing the scan would rewrite the name a
                # second time.  `break` also bypasses the `else` branch.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    break
                if is_pp_missing_parameter(name, self):
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked q/k/v shard: load the tensor as-is.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)