# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import OPTConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned positional embedding whose lookups are shifted by a fixed offset.

    The table is allocated with ``num_embeddings + 2`` rows, and every position
    index passed to :meth:`forward` has 2 added to it before the lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        """Create the table.

        Args:
            num_embeddings: Number of positions to support (before the offset
                rows are added).
            embedding_dim: Size of each embedding vector.
        """
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the embeddings for ``positions`` shifted by ``self.offset``."""
        return super().forward(positions + self.offset)


class OPTAttention(nn.Module):
    """Self-attention block: fused QKV projection, attention, output projection.

    The attention heads are divided evenly across the tensor-parallel world
    size; ``self.num_heads`` is the per-rank head count.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and the attention operator.

        Args:
            embed_dim: Hidden size of the input and output.
            num_heads: Total number of attention heads (across all ranks).
            bias: Whether the QKV and output projections carry a bias.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Dotted module path used to name the sub-modules.

        Raises:
            AssertionError: If ``num_heads`` is not divisible by the
                tensor-parallel world size.
        """
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0, (
            "num_heads must be divisible by the tensor-parallel world size")
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, run attention, and apply the output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection output is split into three equal parts along
        # the last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output


class OPTDecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by a two-layer MLP.

    Each sub-block is wrapped in a residual connection. Depending on
    ``config.do_layer_norm_before`` the layer norm of each sub-block is applied
    either to its input (before) or to the residual sum (after).
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the attention block, the MLP and both layer norms.

        Args:
            config: Model configuration; reads ``hidden_size``,
                ``num_attention_heads``, ``enable_bias``,
                ``do_layer_norm_before``, ``layer_norm_elementwise_affine``,
                ``ffn_dim`` and ``activation_function``.
            cache_config: Forwarded to ``OPTAttention``.
            quant_config: Forwarded to the attention and linear layers.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.activation_fn = get_act_fn(config.activation_function)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and fully connected sub-blocks with residuals."""
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE the fully connected
        # block
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER the fully connected block
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class OPTDecoder(nn.Module):
    """Token/position embeddings plus the stack of decoder layers.

    With pipeline parallelism, only the first rank computes the embeddings and
    only the last rank applies the optional final layer norm and output
    projection; intermediate ranks exchange ``IntermediateTensors``.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build embeddings, optional projections, final norm and the layers.

        Args:
            config: Model configuration.
            cache_config: Forwarded to every decoder layer.
            quant_config: Forwarded to the projections and decoder layers.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist. They are only
        # needed when the word-embedding width differs from the hidden size.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.project_out")
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.project_in")
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: OPTDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder layers owned by this pipeline rank.

        Args:
            input_ids: Token ids; only used on the first rank when
                ``inputs_embeds`` is not given.
            positions: Position indices for the positional embedding.
            kv_caches: Per-layer KV caches, indexed relative to
                ``self.start_layer``.
            attn_metadata: Passed through to every layer.
            intermediate_tensors: Output of the previous pipeline rank;
                required on every rank but the first.
            inputs_embeds: Optional precomputed token embeddings.

        Returns:
            The final hidden states on the last rank, otherwise an
            ``IntermediateTensors`` holding ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            pos_embeds = self.embed_positions(positions)
            if self.project_in is not None:
                inputs_embeds, _ = self.project_in(inputs_embeds)
            hidden_states = inputs_embeds + pos_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


@support_torch_compile
class OPTModel(nn.Module):
    """Thin wrapper that builds an ``OPTDecoder`` from a ``VllmConfig``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the decoder and the intermediate-tensor factory.

        Args:
            vllm_config: Source of the HF model config, cache config and
                quantization config.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.decoder = OPTDecoder(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.decoder")
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the token-embedding lookup to the decoder."""
        return self.decoder.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward all arguments to the decoder and return its result."""
        return self.decoder(input_ids,
                            positions,
                            kv_caches,
                            attn_metadata,
                            intermediate_tensors,
                            inputs_embeds=inputs_embeds)


class OPTForCausalLM(nn.Module, SupportsPP):
    """OPT model with a language-modeling head, logits processor and sampler."""

    # Maps each fused module name to the checkpoint module names it packs.
    # NOTE(review): no module in this file is named gate_proj/up_proj (the MLP
    # uses fc1/fc2), so the "gate_up_proj" entry looks unused here -- confirm
    # against the consumers of this mapping before removing it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the LM head and the sampling helpers.

        Args:
            vllm_config: Source of the HF model config and quantization config;
                also passed unchanged to ``OPTModel``.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "model"))
        # With tied embeddings the LM head is the token-embedding module
        # itself, so no separate head weights exist.
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.word_embed_proj_dim)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the token-embedding lookup to the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return whatever it returns (no LM head)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` checkpoint tensors are
        routed into the fused ``qkv_proj`` parameter with the matching shard
        id. Names starting with ``decoder.`` are re-rooted under ``model.``.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The (possibly rewritten) names of all parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # The head shares the embedding weights when they are tied.
            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                continue
            if name.startswith("decoder."):
                name = "model." + name

            # Rewrite the name at most once. Stopping at the first match
            # matters because "v_proj" is a substring of "qkv_proj", so
            # re-scanning an already rewritten name would mangle it.
            shard_id: Optional[str] = None
            for (param_name, weight_name, candidate) in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = candidate
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is not None:
                # Fused parameters must provide a shard-aware loader.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params