# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import OPTConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position-embedding table used by OPT.

    The table is allocated with ``num_embeddings + 2`` rows, and every
    incoming position id is shifted by that same offset of 2 before the
    lookup, so position ``p`` reads row ``p + 2``.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT offsets position ids by 2 (a quirk inherited from the original
        # OPT implementation); other model families don't do this.
        self.offset = 2
        total_rows = num_embeddings + self.offset
        super().__init__(total_rows, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Return the embedding rows for ``positions`` after the offset."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)


class OPTAttention(nn.Module):
    """Multi-head self-attention for OPT, sharded over tensor-parallel ranks.

    The q/k/v projections are fused into one ``QKVParallelLinear`` and the
    output projection is a ``RowParallelLinear``. Each rank runs
    ``num_heads // tensor_parallel_world_size`` heads.

    Args:
        embed_dim: Model hidden size (input and output width).
        num_heads: Total number of attention heads across all ranks.
        bias: Whether the q/k/v and output projections have a bias term.
        cache_config: Passed through to the ``Attention`` layer.
        quant_config: Passed through to the linear layers and ``Attention``.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``, or
            ``num_heads`` is not divisible by the tensor-parallel world size.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        # Validate with real exceptions rather than `assert` (which is
        # stripped under `python -O`). Without the first check,
        # `embed_dim // num_heads` silently truncates head_dim and the
        # mismatch only surfaces later as an opaque shape error.
        if embed_dim % total_num_heads != 0:
            raise ValueError(
                "embed_dim must be divisible by num_heads (got embed_dim="
                f"{embed_dim} and num_heads={total_num_heads}).")
        if total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                "num_heads must be divisible by the tensor-parallel world "
                f"size (got num_heads={total_num_heads} and world size="
                f"{tensor_model_parallel_world_size}).")
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply self-attention and return the ``out_proj`` output.

        The second value returned by each parallel linear layer is
        discarded.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # No separate kv-head count is given to QKVParallelLinear, so the
        # fused output is split into three equal q/k/v chunks.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output

class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention, then a two-layer MLP.

    Both sub-blocks are wrapped in a residual connection. When
    ``config.do_layer_norm_before`` is true, each LayerNorm is applied to
    the sub-block input (pre-LN). Otherwise it is applied to the residual
    sum (post-LN).
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        width = self.embed_dim
        use_bias = config.enable_bias

        self.self_attn = OPTAttention(
            embed_dim=width,
            num_heads=config.num_attention_heads,
            bias=use_bias,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        affine = config.layer_norm_elementwise_affine
        self.self_attn_layer_norm = nn.LayerNorm(width,
                                                 elementwise_affine=affine)
        self.fc1 = ColumnParallelLinear(width,
                                        config.ffn_dim,
                                        bias=use_bias,
                                        quant_config=quant_config)
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config, config.ffn_dim)
        self.fc2 = RowParallelLinear(config.ffn_dim,
                                     width,
                                     bias=use_bias,
                                     quant_config=quant_config)
        self.final_layer_norm = nn.LayerNorm(width,
                                             elementwise_affine=affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and MLP sub-blocks with residual connections."""
        pre_ln = self.do_layer_norm_before

        # Self-attention sub-block. Most OPT sizes (125m, 1.7B, ..., 175B)
        # normalize before attention; 350m normalizes after the residual add.
        skip = hidden_states
        if pre_ln:
            attn_in = self.self_attn_layer_norm(hidden_states)
        else:
            attn_in = hidden_states
        attn_out = self.self_attn(hidden_states=attn_in,
                                  kv_cache=kv_cache,
                                  attn_metadata=attn_metadata)
        hidden_states = skip + attn_out
        if not pre_ln:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward sub-block, with the same pre/post-LN placement.
        skip = hidden_states
        if pre_ln:
            mlp_in = self.final_layer_norm(hidden_states)
        else:
            mlp_in = hidden_states
        mlp_hidden, _ = self.fc1(mlp_in)
        mlp_hidden = self.activation_fn(mlp_hidden)
        mlp_out, _ = self.fc2(mlp_hidden)
        hidden_states = skip + mlp_out
        if not pre_ln:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class OPTDecoder(nn.Module):
    """Embeddings, the decoder-layer stack, and the optional output stage.

    When ``config.word_embed_proj_dim`` differs from ``config.hidden_size``,
    replicated bias-free projections map embeddings into the hidden size
    (``project_in``) and the final hidden states back out
    (``project_out``). Only the layers in ``[start_layer, end_layer)`` are
    run by this instance; the pipeline-parallel group decides whether the
    input comes from the embeddings or from ``intermediate_tensors``.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # The in/out projections are replicated and exist only when the
        # word-embedding width differs from the hidden size.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                quant_config=quant_config)
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               quant_config=quant_config)
        else:
            self.project_out = None
            self.project_in = None

        # `config._remove_final_layer_norm` exists only for backward
        # compatibility with checkpoints fine-tuned before transformers
        # v4.20.1; see https://github.com/facebookresearch/metaseq/pull/164
        keep_final_norm = (config.do_layer_norm_before
                           and not config._remove_final_layer_norm)
        if keep_final_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        def build_layer(prefix: str):
            # make_layers passes a per-layer `prefix`; it is not used here.
            return OPTDecoderLayer(config, cache_config, quant_config)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, build_layer, prefix=f"{prefix}.layers")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        embeddings = self.embed_tokens(input_ids)
        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder.

        On the first pipeline rank the input is the token embeddings (or
        the given ``inputs_embeds``), passed through ``project_in`` when
        present, plus the position embeddings. Other ranks read
        ``hidden_states`` from ``intermediate_tensors``. Non-last ranks
        return ``IntermediateTensors``; the last rank applies the optional
        final LayerNorm and ``project_out`` and returns the hidden states.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            pos_embeds = self.embed_positions(positions)
            if self.project_in is not None:
                inputs_embeds, _ = self.project_in(inputs_embeds)
            hidden_states = inputs_embeds + pos_embeds

        # kv_caches holds only this rank's layers, hence the local index.
        layer_range = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(layer_range):
            hidden_states = self.layers[layer_idx](hidden_states,
                                                   kv_caches[cache_idx],
                                                   attn_metadata)

        if not pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


class OPTModel(nn.Module):
    """Holds the ``OPTDecoder`` as the ``decoder`` submodule.

    It also exposes ``make_empty_intermediate_tensors``, a factory for
    empty ``IntermediateTensors`` with a single ``hidden_states`` entry of
    width ``config.hidden_size``.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.decoder = OPTDecoder(config, cache_config, quant_config)
        empty_tensors_factory = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size)
        self.make_empty_intermediate_tensors = empty_tensors_factory

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the decoder."""
        decoder = self.decoder
        return decoder.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder and return its output unchanged."""
        decoder_args = (input_ids, positions, kv_caches, attn_metadata,
                        intermediate_tensors)
        return self.decoder(*decoder_args, inputs_embeds=inputs_embeds)


class OPTForCausalLM(nn.Module, SupportsPP):
    """OPT language model with an LM head, logits processor, and sampler.

    When ``config.tie_word_embeddings`` is set, the decoder's token
    embedding module is reused as ``lm_head``. Otherwise a separate
    ``ParallelLMHead`` of width ``config.word_embed_proj_dim`` is created.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(config, cache_config, quant_config)
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.word_embed_proj_dim)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the underlying ``OPTModel`` and return its output."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HuggingFace OPT checkpoint tensors into this model.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Names starting with ``decoder.`` are re-rooted under ``model.``.
        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are loaded as
        shards of the fused ``qkv_proj`` parameter. ``lm_head.weight`` is
        ignored when word embeddings are tied. Bias tensors with no
        matching parameter (extra GPTQ biases) and parameters that belong
        to other pipeline stages are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                continue
            if name.startswith("decoder."):
                name = "model." + name

            # Resolve the target parameter (and shard, for fused q/k/v)
            # exactly once. Skip decisions must then apply to the whole
            # checkpoint entry: a `continue` inside the mapping loop would
            # only advance to the next mapping, and "qkv_proj" itself
            # contains "v_proj", so the renamed name would be rewritten
            # again (e.g. to "qkqkv_proj").
            shard_id = None
            for param_name, weight_name, mapped_shard_id in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = mapped_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)