"vscode:/vscode.git/clone" did not exist on "55a1a9563a7f8600cdc336e76d2074cef8ffe8e5"
transformers.py 11.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
import re
18
from typing import Iterable, Literal, Optional, Union
19
20
21
22
23
24
25
26
27
28
29

import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from vllm.attention import Attention, AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.logger import init_logger
30
31
32
33
34
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA)
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              ReplicatedLinearWithLoRA,
                              RowParallelLinearWithLoRA)
35
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
36
                                               ReplicatedLinear,
37
38
39
40
41
42
43
44
45
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

46
from .interfaces import SupportsQuant
47
48
49
50
51
52
53
54
55
56
57
58
59
from .utils import maybe_prefix

logger = init_logger(__name__)


def vllm_flash_attention_forward(
        # Transformers args
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor,
        # Transformers kwargs
        scaling: Optional[float] = None,
        # vLLM kwargs
        attn_metadata: Optional[AttentionMetadata] = None,
        attention_instances: Optional[list[Attention]] = None,
        **kwargs):
    """Attention function registered with Transformers under the key "vllm".

    Routes the attention computation of a Transformers attention module to
    the vLLM `Attention` instance selected by that module's `layer_idx`.

    Args:
        module: The calling Transformers attention module; its `layer_idx`
            indexes into `attention_instances`.
        query, key, value: Projected attention inputs. NOTE(review): assumed
            to be laid out as (batch, num_heads, num_tokens, head_dim) with
            batch == 1 — TODO confirm against the Transformers version used.
        attention_mask: Accepted for signature compatibility; not used.
        scaling: Softmax scale chosen by Transformers. When not None it is
            written onto the vLLM attention implementation (and therefore
            persists for later calls).
        attn_metadata: vLLM attention metadata passed to `Attention.forward`.
        attention_instances: vLLM `Attention` objects, one per layer index.
        **kwargs: Any other Transformers kwargs; ignored.

    Returns:
        A `(attn_output, None)` tuple; `None` takes the place of attention
        weights, which are not computed here.

    Raises:
        ValueError: If `attention_instances` was not supplied.
    """
    if attention_instances is None:
        raise ValueError(
            "vllm_flash_attention_forward requires `attention_instances`; "
            "pass them through the model's forward kwargs.")
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        self_attn.impl.scale = float(scaling)
    num_tokens = query.shape[-2]
    # Swap dims 1 and 2, then flatten everything after the token dim into one
    # feature dim. Flattening to (num_tokens, -1) is only a per-token layout
    # when the leading batch dim is 1.
    query, key, value = (x.transpose(1, 2).reshape(num_tokens, -1)
                         for x in (query, key, value))
    return self_attn.forward(
        query,
        key,
        value,
        kv_cache=None,  # argument not used
        attn_metadata=attn_metadata), None


# Register `vllm_flash_attention_forward` with Transformers under the name
# "vllm". Models built with `attn_implementation="vllm"` (as done in
# TransformersModel.__init__) are expected to look up their attention
# function under this key.
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


82
83
84
85
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Emit a DEBUG log line noting that `old_module` at `name` was swapped
    for `new_module`."""
    template = "%s: %s -> %s"
    logger.debug(template, name, old_module, new_module)


86
87
def replace_linear_class(
        linear: nn.Linear,
        style: Literal["colwise", "rowwise"],
        quant_config=None,
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Args:
        linear (nn.Linear): `nn.Linear` to be replaced. Only its shape
            (`in_features`, `out_features`) and whether it has a bias are
            used; its weights are not copied.
        style (str): Tensor parallel style of the new linear. "colwise"
            selects `ColumnParallelLinear`, "rowwise" selects
            `RowParallelLinear`, and any other string falls back to
            `ReplicatedLinear`.
        quant_config (QuantConfig): Quantization config for the new linear,
            forwarded unchanged to the vLLM linear constructor.
    Returns:
        Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
            The new linear. Its `forward` returns only the output tensor
            rather than vLLM's `(output, output_bias)` pair, so it can
            stand in for `nn.Linear`.
    Raises:
        ValueError: If `style` is not a string.
    """

    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str")

    vllm_linear_cls = {
        "colwise": ColumnParallelLinear,
        "rowwise": RowParallelLinear,
    }.get(style, ReplicatedLinear)

    # vLLM linear class -> {fully_sharded: LoRA wrapper class}.
    lora_linear_cls = {
        ColumnParallelLinear: {
            True: ColumnParallelLinearWithShardedLoRA,  # fully sharded
            False: ColumnParallelLinearWithLoRA  # not fully sharded
        },
        RowParallelLinear: {
            True: RowParallelLinearWithShardedLoRA,
            False: RowParallelLinearWithLoRA
        },
        # ReplicatedLinear doesn't support fully sharded LoRA yet,
        # so we use the same class for both cases.
        ReplicatedLinear: {
            True: ReplicatedLinearWithLoRA,
            False: ReplicatedLinearWithLoRA
        }
    }

    class HFCompatibleLinear(vllm_linear_cls):
        """
        Wrapper class that removes `output_bias` from returned output.
        """

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return super().forward(input)[0]

        @classmethod
        def get_lora_class(cls, fully_sharded: bool = False):
            """
            Get the LoRA class corresponding to the current transformer
            linear class.

            Args:
                fully_sharded (bool): If True, select the LoRA class variant
                that supports fully sharded LoRA. Defaults to False.

            """
            return lora_linear_cls[vllm_linear_cls][fully_sharded]

    return HFCompatibleLinear(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
    )

156

157
class TransformersModel(nn.Module, SupportsQuant):
    """Run a Hugging Face Transformers model on vLLM.

    The model is built with `AutoModel.from_config(attn_implementation=
    "vllm")` and then patched in place: linears named in the config's
    `base_model_tp_plan` become vLLM tensor-parallel linears, the input
    embedding becomes a `VocabParallelEmbedding`, and one vLLM `Attention`
    per hidden layer is created for the "vllm" attention function to use.
    A vLLM `ParallelLMHead`, `LogitsProcessor` and sampler are added on top.
    """

    embedding_padding_modules = ["lm_head"]
    # TODO transformers will have a util to get it
    embedding_modules = ["embed_tokens"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        logger.info("Using Transformers backend.")

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size

        self.model: PreTrainedModel = AutoModel.from_config(
            self.config,
            attn_implementation="vllm",
            torch_dtype=vllm_config.model_config.dtype,
            trust_remote_code=vllm_config.model_config.trust_remote_code,
        )
        # NOTE(review): this discards the `prefix` argument and uses the
        # Transformers base model prefix instead; it only feeds the lm_head
        # prefix below. Confirm this is intended.
        prefix = self.model.base_model_prefix

        # NOTE(review): `self.quant_config` is never assigned in this class;
        # presumably the SupportsQuant mixin provides it — confirm.

        # MLP modifications
        self.apply_base_model_tp_plan(self.model)

        # Attention modifications (assumes 1 attention op per hidden layer)
        tp_size = get_tensor_model_parallel_world_size()
        # Not every config defines head_dim / num_key_value_heads; fall back
        # to hidden_size // num_attention_heads and to plain multi-head
        # attention respectively.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // config.num_attention_heads
        num_kv_heads = getattr(config, "num_key_value_heads", None)
        if num_kv_heads is None:
            num_kv_heads = config.num_attention_heads
        self.attention_instances = [
            Attention(
                num_heads=divide(config.num_attention_heads, tp_size),
                head_size=head_dim,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_dim**-0.5,
                num_kv_heads=divide(num_kv_heads, tp_size),
                cache_config=cache_config,
                quant_config=self.quant_config,
                prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
        ]

        # Model modifications
        self.replace_vocab_embed_class(self.model)

        # ForCausalLM modifications
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=self.quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.get_input_embeddings().weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = get_sampler()

    def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
        """
        Apply the base model tensor parallelization plan to a module.
        Currently only supports linear layers.

        Walks `module`'s children recursively. The first plan pattern that
        `re.match`es a child's qualified name, when that child is an
        `nn.Linear`, decides its replacement via `replace_linear_class`;
        every child that is not replaced is descended into.

        Args:
            module: Module whose children are rewritten in place.
            prefix: Qualified name of `module`; child names are built from
                it and matched against the plan's patterns.

        Raises:
            ValueError: If tensor parallelism is requested (world size > 1)
                but the config has no `base_model_tp_plan`.
        """
        tp_plan = getattr(self.config, "base_model_tp_plan", None)
        if tp_plan is None:
            if get_tensor_model_parallel_world_size() > 1:
                raise ValueError(
                    "Trying to run tensor parallelization but the model does not "
                    "support it yet!")
            # No plan and no TP: there is nothing to replace.
            return
        if not tp_plan:
            return

        for child_name, child_module in module.named_children():
            qual_name = maybe_prefix(prefix, child_name)
            for pattern, style in tp_plan.items():
                if re.match(pattern, qual_name) and isinstance(
                        child_module, nn.Linear):
                    new_module = replace_linear_class(child_module, style,
                                                      self.quant_config)
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)
                    # Stop at the first match: don't rebuild the linear for
                    # later patterns and don't recurse into the old module.
                    break
            else:
                self.apply_base_model_tp_plan(child_module, prefix=qual_name)

    def replace_vocab_embed_class(self, module: nn.Module):
        """Replace `module`'s input embedding with `VocabParallelEmbedding`.

        Args:
            module: Model exposing `get_input_embeddings` /
                `set_input_embeddings`; modified in place.
        """
        # Use native set input embeddings
        new_module = VocabParallelEmbedding(
            self.vocab_size,
            self.config.hidden_size,
            org_num_embeddings=self.config.vocab_size,
            quant_config=None,
        )
        log_replacement("input embedding", module.get_input_embeddings(),
                        new_module)
        module.set_input_embeddings(new_module)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: list[torch.Tensor],  # argument not used
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped Transformers model.

        A singleton batch dimension is prepended to `input_ids` and
        `positions`, and removed again from the first model output.
        `attn_metadata` and `attention_instances` are passed as extra
        kwargs (presumably forwarded by Transformers to the registered
        "vllm" attention function — confirm).

        NOTE(review): `kv_caches` and `inputs_embeds` are accepted but not
        used.
        """
        model_output = self.model(
            input_ids[None, ...],
            use_cache=False,
            position_ids=positions[None, ...],
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            attention_instances=self.attention_instances,
            return_dict=False)[0][0, ...]  # we remove batch dimension for now
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to logits via `lm_head` and the logits
        processor."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        """Sample next tokens from `logits` with the vLLM sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        A name not found as-is is retried with the Transformers base model
        prefix prepended; names that still don't match are skipped.

        Returns:
            The set of parameter names that received a weight.
        """
        params_dict = dict(self.named_parameters())
        loaded_params = set[str]()
        for name, loaded_weight in weights:
            if name not in params_dict:
                name = f"{self.model.base_model_prefix}.{name}"
            if name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        return loaded_params