# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
19
from typing import Iterable, List, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
20

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
21
22
23
24
import torch
from torch import nn
from transformers import OPTConfig

25
from vllm.attention import Attention, AttentionMetadata
26
from vllm.compilation.decorators import support_torch_compile
27
from vllm.config import CacheConfig, VllmConfig
28
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
29
from vllm.model_executor.layers.activation import get_act_fn
30
31
32
33
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
34
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
36
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
37
from vllm.model_executor.layers.vocab_parallel_embedding import (
38
    ParallelLMHead, VocabParallelEmbedding)
39
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
40
from vllm.model_executor.sampling_metadata import SamplingMetadata
41
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
42

43
44
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
45
46
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
47

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
48
49
50
51

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned positional embedding whose lookups are shifted by an offset.

    The table is allocated with ``num_embeddings + 2`` rows, and every
    position id has 2 added to it before the lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        """Allocate the table.

        Args:
            num_embeddings: Number of positions to support; the underlying
                table gets ``self.offset`` additional rows.
            embedding_dim: Width of each embedding vector.
        """
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the embeddings for ``positions`` shifted by ``self.offset``."""
        return super().forward(positions + self.offset)


class OPTAttention(nn.Module):
    """OPT self-attention built from parallel QKV and output projections.

    Args:
        embed_dim: Size of the layer's input and output features.
        num_heads: Total number of attention heads; must be divisible by the
            tensor-model-parallel world size.
        bias: Whether the QKV and output projections use a bias.
        cache_config: Passed through to the ``Attention`` layer.
        quant_config: Passed through to both projections and the
            ``Attention`` layer.
        prefix: Dotted module path used to build the sub-layer prefixes.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0
        # `self.num_heads` is the per-rank head count; `head_dim` is derived
        # from the total head count.
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, run attention, and apply the output projection.

        Args:
            hidden_states: Input to the fused QKV projection.
            kv_cache: Cache tensor handed to the ``Attention`` layer.
            attn_metadata: Metadata handed to the ``Attention`` layer.

        Returns:
            The first element returned by ``out_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection output is split into three equal parts along
        # the last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output

Woosuk Kwon's avatar
Woosuk Kwon committed
117

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
118
119
class OPTDecoderLayer(nn.Module):
    """One OPT decoder layer: self-attention followed by a two-layer MLP.

    Each sub-block has a residual connection and a ``LayerNorm`` that is
    applied either before or after the sub-block, selected by
    ``config.do_layer_norm_before``.

    Args:
        config: OPT configuration; this layer reads ``hidden_size``,
            ``num_attention_heads``, ``enable_bias``, ``ffn_dim``,
            ``activation_function``, ``do_layer_norm_before`` and
            ``layer_norm_elementwise_affine``.
        cache_config: Passed through to ``OPTAttention``.
        quant_config: Passed through to the attention and MLP projections.
        prefix: Dotted module path used to build the sub-layer prefixes.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.activation_fn = get_act_fn(config.activation_function)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention block and then the feed-forward block.

        Args:
            hidden_states: Layer input.
            kv_cache: Cache tensor forwarded to the attention module.
            attn_metadata: Metadata forwarded to the attention module.

        Returns:
            The hidden states after both residual sub-blocks.
        """
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE the feed-forward
        # block
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER the feed-forward block
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
196
class OPTDecoder(nn.Module):
    """OPT decoder stack: embeddings, decoder layers and optional final norm.

    Args:
        config: OPT configuration.
        cache_config: Passed through to every ``OPTDecoderLayer``.
        quant_config: Passed through to the projections and decoder layers.
        prefix: Dotted module path used to build the sub-layer prefixes.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist. Both are only
        # needed when the token-embedding width differs from the hidden size.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.project_out")
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.project_in")
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: OPTDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder layers owned by this pipeline-parallel rank.

        Args:
            input_ids: Token ids; embedded on the first rank when
                ``inputs_embeds`` is not given.
            positions: Position ids for the positional embedding.
            kv_caches: Per-layer caches, indexed relative to
                ``self.start_layer``.
            attn_metadata: Metadata forwarded to every layer.
            intermediate_tensors: Must be provided on non-first ranks; its
                ``"hidden_states"`` entry is used as the layer input.
            inputs_embeds: Optional precomputed token embeddings.

        Returns:
            ``IntermediateTensors`` holding ``"hidden_states"`` when this is
            not the last rank; otherwise the final hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            pos_embeds = self.embed_positions(positions)
            if self.project_in is not None:
                inputs_embeds, _ = self.project_in(inputs_embeds)
            hidden_states = inputs_embeds + pos_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # `kv_caches` is indexed from 0 for the first local layer.
            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


293
@support_torch_compile
class OPTModel(nn.Module):
    """Thin wrapper that owns the ``OPTDecoder`` under the ``decoder`` name.

    Args:
        vllm_config: Source of the HF model config, cache config and
            quantization config.
        prefix: Dotted module path; the decoder is created under
            ``f"{prefix}.decoder"``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.decoder = OPTDecoder(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.decoder")
        # Factory for the placeholder "hidden_states" tensors that are
        # exchanged between pipeline-parallel ranks.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the decoder's token embeddings for ``input_ids``."""
        return self.decoder.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to ``self.decoder`` with the same arguments."""
        return self.decoder(input_ids,
                            positions,
                            kv_caches,
                            attn_metadata,
                            intermediate_tensors,
                            inputs_embeds=inputs_embeds)
Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
329
330


331
class OPTForCausalLM(nn.Module, SupportsPP):
    """OPT with a language-model head, logits processor and sampler.

    Args:
        vllm_config: Source of the HF model config and quantization config;
            also forwarded to ``OPTModel``.
        prefix: Dotted module path; the model is created under
            ``maybe_prefix(prefix, "model")``.
    """

    # NOTE(review): no module named `gate_up_proj` is defined in this file;
    # the entry is kept unchanged because external code may read this mapping.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "model"))
        # With tied embeddings the token-embedding module doubles as the
        # LM head.
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.word_embed_proj_dim)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the model's token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run ``self.model`` and return its output unchanged."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` using the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded, after renaming.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # With tied embeddings there is no separate LM-head weight to
            # load.
            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                continue
            # Names that start at the decoder level are re-rooted under
            # `model.` to match this module's parameter names.
            if name.startswith("decoder."):
                name = "model." + name

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE(review): the two `continue`s below advance this inner
                # mapping loop, not the outer weights loop; skipping the
                # weight relies on the same guards in the `else` branch.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when the inner loop did not `break`, i.e. no
                # stacked shard was loaded for this name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params