"vllm/entrypoints/openai/serving_pooling.py" did not exist on "e254497b66dcd87038969b0ad34d34425edfc5fe"
opt.py 16.1 KB
Newer Older
# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
21
from typing import Iterable, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
22

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
23
24
25
26
import torch
from torch import nn
from transformers import OPTConfig

27
from vllm.attention import Attention
28
from vllm.compilation.decorators import support_torch_compile
29
from vllm.config import CacheConfig, VllmConfig
30
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
31
from vllm.model_executor.layers.activation import get_act_fn
32
33
34
35
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
36
from vllm.model_executor.layers.logits_processor import LogitsProcessor
37
from vllm.model_executor.layers.quantization import QuantizationConfig
38
from vllm.model_executor.layers.vocab_parallel_embedding import (
39
    ParallelLMHead, VocabParallelEmbedding)
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
41
from vllm.model_executor.sampling_metadata import SamplingMetadata
42
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
43

44
from .interfaces import SupportsPP
45
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
46
47
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
48

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
49
50
51
52

class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned positional embedding whose lookups are shifted by an offset.

    The table is allocated with ``num_embeddings + 2`` rows, and every
    position id passed to :meth:`forward` is increased by 2 before the
    lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        """Allocate the embedding table.

        Args:
            num_embeddings: Number of positions to support; the underlying
                table gets ``self.offset`` extra rows on top of this.
            embedding_dim: Size of each embedding vector.
        """
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor):
        """Return the embeddings stored at ``positions + self.offset``."""
        return super().forward(positions + self.offset)


class OPTAttention(nn.Module):
    """Self-attention block: fused QKV projection, attention, output proj.

    The total head count is divided by the tensor-parallel world size, so
    ``self.num_heads`` is the number of heads handled by this rank.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and the attention layer.

        Args:
            embed_dim: Input/output width of the block.
            num_heads: Total number of attention heads across all
                tensor-parallel ranks; must be divisible by the
                tensor-parallel world size.
            bias: Whether the QKV and output projections carry a bias.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Dotted name prefix used to name the sub-layers.
        """
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        # Scale factor 1 / sqrt(head_dim), handed to the attention layer.
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, run attention, and apply the output projection."""
        # The linear layers return a tuple; only the first element is used.
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into three equal parts on the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output

Woosuk Kwon's avatar
Woosuk Kwon committed
116

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
117
118
class OPTDecoderLayer(nn.Module):
    """One OPT decoder layer: self-attention then a two-layer feed-forward.

    Each sub-block is wrapped in a residual connection. Depending on
    ``config.do_layer_norm_before`` the layer norm of each sub-block is
    applied either to its input (pre-norm) or to the residual sum
    (post-norm).
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the attention block, feed-forward layers and layer norms.

        Args:
            config: OPT configuration; the fields read here are
                ``hidden_size``, ``num_attention_heads``, ``enable_bias``,
                ``do_layer_norm_before``, ``layer_norm_elementwise_affine``,
                ``ffn_dim`` and ``activation_function``.
            cache_config: Forwarded to ``OPTAttention``.
            quant_config: Forwarded to the attention and linear layers.
            prefix: Dotted name prefix used to name the sub-layers.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.activation_fn = get_act_fn(config.activation_function)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention and the feed-forward block with residuals."""
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE the feed-forward
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        # The linear layers return a tuple; only the first element is used.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER the feed-forward
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


Zhuohan Li's avatar
Zhuohan Li committed
191
class OPTDecoder(nn.Module):
    """OPT decoder stack: embeddings, decoder layers and final norm/proj.

    Token and positional embeddings are combined on the first
    pipeline-parallel rank; the final layer norm and output projection are
    applied on the last rank only.
    """

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build embeddings, optional in/out projections and the layers.

        Args:
            config: OPT configuration.
            cache_config: Forwarded to every ``OPTDecoderLayer``.
            quant_config: Forwarded to the projections and decoder layers.
            prefix: Dotted name prefix used to name the sub-layers.
        """
        super().__init__()
        self.config = config
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist. They are only
        # needed when the word-embedding width differs from the hidden size.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.project_out")
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.project_in")
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        # NOTE(review): start_layer/end_layer presumably delimit the layers
        # owned by this pipeline-parallel rank -- confirm in `make_layers`.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: OPTDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder layers in ``[start_layer, end_layer)``.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Position ids for the positional embedding.
            intermediate_tensors: Must be provided on non-first pipeline
                ranks; its ``"hidden_states"`` entry is the layer input.
            inputs_embeds: Optional precomputed token embeddings used in
                place of embedding ``input_ids``.

        Returns:
            On the last pipeline rank, the hidden states after the optional
            final layer norm and output projection. On other ranks, an
            ``IntermediateTensors`` wrapping the hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            pos_embeds = self.embed_positions(positions)
            if self.project_in is not None:
                inputs_embeds, _ = self.project_in(inputs_embeds)
            hidden_states = inputs_embeds + pos_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = layer(hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


282
@support_torch_compile
class OPTModel(nn.Module):
    """Wrapper around ``OPTDecoder`` that also handles weight loading."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder from the model, cache and quant configs.

        Args:
            vllm_config: Source of ``model_config.hf_config``,
                ``cache_config`` and ``quant_config``.
            prefix: Dotted name prefix; the decoder is named
                ``f"{prefix}.decoder"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.decoder = OPTDecoder(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.decoder")
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.decoder.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the decoder; see ``OPTDecoder.forward``."""
        return self.decoder(input_ids,
                            positions,
                            intermediate_tensors,
                            inputs_embeds=inputs_embeds)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are routed into
        the fused ``qkv_proj`` parameter with the matching shard id; all
        other tensors are loaded under their own name.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The (possibly remapped) names of the parameters that were
            actually loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Resolve the target parameter name first and stop at the first
            # match: once a name has been rewritten to "...qkv_proj...", it
            # must not be matched again ("v_proj" is a substring of
            # "qkv_proj").
            shard_id = None
            for param_name, weight_name, candidate in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = candidate
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is not None:
                # Fused parameters must provide a shard-aware loader.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

Woosuk Kwon's avatar
Add OPT  
Woosuk Kwon committed
352

353
class OPTForCausalLM(nn.Module, SupportsPP):
    """OPT with a language-modeling head for causal generation."""

    # Fused parameter name -> names of the separate checkpoint tensors.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # Checkpoint names starting with "decoder." live under "model.decoder."
    # in this module tree.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "decoder.": "model.decoder.",
    })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head and logits processor.

        Args:
            vllm_config: Source of ``model_config.hf_config`` and
                ``quant_config``; also passed on to ``OPTModel``.
            prefix: Dotted name prefix for the backbone.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "model"))
        if self.config.tie_word_embeddings:
            # Reuse the input embedding module as the output head.
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.word_embed_proj_dim)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors through ``AutoWeightsLoader``.

        When the word embeddings are tied, ``lm_head.weight`` is skipped
        because the head shares the input embedding module.

        Returns:
            The set of names reported as loaded by the loader.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head.weight"]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)