"tests/vscode:/vscode.git/clone" did not exist on "ac32e66cf95c40502ad7b00b2e80dfb0315bfee4"
transformers.py 9.82 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
import re
18
from typing import Iterable, Literal, Optional, Union
19
20
21
22
23
24

import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

25
from vllm.attention import Attention
26
27
28
29
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
30
                                               ReplicatedLinear,
31
32
33
34
35
36
37
38
39
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

40
from .interfaces import SupportsLoRA, SupportsQuant
41
42
43
44
45
46
47
48
49
50
51
52
53
from .utils import maybe_prefix

# Module-level logger shared by the helpers and the model class in this file.
logger = init_logger(__name__)


def vllm_flash_attention_forward(
        # Transformers args
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor,
        # Transformers kwargs
        scaling: Optional[float] = None,
        # vLLM kwargs
        attention_instances: Optional[list["Attention"]] = None,
        **kwargs):
    """Route a transformers attention call to the matching vLLM layer.

    This function is registered in transformers' attention-function registry
    under the name "vllm". It flattens the per-head query/key/value tensors
    into 2-D ``(seq_len, -1)`` tensors. It then delegates to the vLLM
    ``Attention`` instance selected by ``module.layer_idx``.

    Args:
        module: The calling transformers attention module. Only its
            ``layer_idx`` attribute is read.
        query, key, value: Assumed to be (batch, num_heads, seq_len,
            head_dim). The reshape below folds batch into the feature dim,
            so this only makes sense for batch == 1. The model's
            ``forward`` adds a size-1 batch dim, so that holds there.
        attention_mask: Unused here.
        scaling: If given, overwrites ``impl.scale`` on the selected vLLM
            attention layer.
        attention_instances: vLLM ``Attention`` layers, indexed by layer.
        **kwargs: Any other transformers kwargs; ignored.

    Returns:
        ``(attn_output, None)``. No attention weights are produced.

    Raises:
        ValueError: If ``attention_instances`` is not provided.
    """
    if attention_instances is None:
        raise ValueError(
            "vllm_flash_attention_forward requires `attention_instances`; "
            "pass them through the model's forward kwargs.")
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        # The layer was built with a Llama-style default scale; prefer the
        # value transformers computed for this model.
        self_attn.impl.scale = float(scaling)
    # Dim -2 of the (batch, heads, seq, head_dim) inputs is the sequence.
    seq_len = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(seq_len, -1) for x in (query, key, value))
    return self_attn.forward(query, key, value), None


# Register the vLLM-backed attention under the name "vllm". The model below is
# built with `attn_implementation="vllm"`; transformers resolves that name
# through this registry, so the lookup depends on this entry.
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Log, at debug level, that ``old_module`` was replaced by ``new_module``.

    Args:
        name: Qualified name (or short description) of the replaced module.
        old_module: The module that was removed.
        new_module: The module that took its place.
    """
    logger.debug("%s: %s -> %s", name, old_module, new_module)


def replace_linear_class(
    linear: nn.Linear,
    style: Literal["colwise", "rowwise"],
    quant_config=None,
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Args:
        linear (nn.Linear): `nn.Linear` to be replaced. Only its shape and
            whether it has a bias are read. Its weights are not copied.
        style (str): Tensor parallel style of the new linear: "colwise" or
            "rowwise". Any other string yields a `ReplicatedLinear`.
        quant_config (QuantConfig): Quantization config forwarded to the new
            linear's constructor.
    Returns:
        Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
            The new, freshly initialised linear.
    Raises:
        ValueError: If `style` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str")

    vllm_linear_cls = {
        "colwise": ColumnParallelLinear,
        "rowwise": RowParallelLinear,
    }.get(style, ReplicatedLinear)

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        # Return just the output tensor, like nn.Linear, instead of an
        # (output, bias) tuple, so transformers code can call it unchanged.
        return_bias=False,
    )

class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
    """vLLM causal-LM wrapper around an arbitrary `transformers` base model.

    The base model is built with `AutoModel.from_config` using the "vllm"
    attention implementation. Its linear layers (per the config's
    `base_model_tp_plan`) and its input embedding are then swapped for vLLM
    parallel equivalents. The LM head, logits processor and sampler are vLLM
    components owned by this wrapper.
    """
    embedding_padding_modules = ["lm_head"]
    embedding_modules = ["embed_tokens"
                         ]  # TODO transformers will have a util to get it

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        logger.info("Using Transformers backend.")

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config

        self.config = config
        self.vocab_size = model_config.get_vocab_size()
        self.unpadded_vocab_size = self.vocab_size

        self.model: PreTrainedModel = AutoModel.from_config(
            self.config,
            attn_implementation="vllm",
            torch_dtype=model_config.dtype,
            trust_remote_code=model_config.trust_remote_code,
        )
        # NOTE(review): this discards the `prefix` argument. The LM head
        # prefix below is therefore built from the HF base-model prefix
        # rather than the caller-supplied one. Confirm this is intended.
        prefix = self.model.base_model_prefix

        # NOTE(review): `self.quant_config` is never assigned in this class;
        # presumably the SupportsQuant mixin provides it. Confirm.

        # MLP modifications
        self.apply_base_model_tp_plan(self.model)

        # Attention modifications (assumes 1 attention op per hidden layer)
        num_heads = model_config.get_num_attention_heads(parallel_config)
        head_size = model_config.get_head_size()
        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.attention_instances = [
            Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=cache_config,
                quant_config=self.quant_config,
                prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
        ]

        # Model modifications
        self.replace_vocab_embed_class(self.model)

        # ForCausalLM modifications
        self.lm_head = ParallelLMHead(self.vocab_size,
                                      config.hidden_size,
                                      quant_config=self.quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.get_input_embeddings().weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.vocab_size, logit_scale)
        self.sampler = get_sampler()

    def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
        """
        Apply the base model tensor parallelization plan to a module.
        Currently only supports linear layers.

        Walks `module`'s children recursively. A child is replaced when it
        is an `nn.Linear` and its qualified name (relative to the module
        originally passed in) matches a regex key of
        `config.base_model_tp_plan`. Matching uses `re.match`, so patterns
        are anchored at the start only. The replacement is the vLLM linear
        for that key's style. The first matching pattern wins, and a
        replaced module is not descended into.

        Raises:
            ValueError: If tensor parallelism is requested but the config
                provides no (or an empty) `base_model_tp_plan`.
        """
        # A missing or None plan is treated as "nothing to shard". The
        # previous code called `.items()` on None when running without TP.
        tp_plan = getattr(self.config, "base_model_tp_plan", None) or {}
        if not tp_plan and get_tensor_model_parallel_world_size() > 1:
            raise ValueError(
                "Trying to run tensor parallelization but the model does not "
                "support it yet!")
        if not tp_plan:
            return

        for child_name, child_module in module.named_children():
            qual_name = maybe_prefix(prefix, child_name)
            for pattern, style in tp_plan.items():
                if re.match(pattern, qual_name) and isinstance(
                        child_module, nn.Linear):
                    new_module = replace_linear_class(child_module, style,
                                                      self.quant_config)
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)
                    # Without this break the `else` below always ran. Each
                    # further matching pattern would also build another
                    # throw-away replacement layer.
                    break
            else:
                self.apply_base_model_tp_plan(child_module, prefix=qual_name)

    def replace_vocab_embed_class(self, module: nn.Module):
        """Swap `module`'s input embedding for a `VocabParallelEmbedding`.

        The new embedding has `self.vocab_size` rows and `hidden_size`
        columns and is freshly constructed. The old weights are not copied
        (presumably `load_weights` fills them; confirm).
        """
        # Use native set input embeddings
        new_module = VocabParallelEmbedding(
            self.vocab_size,
            self.config.hidden_size,
            org_num_embeddings=self.vocab_size,
            quant_config=None,
        )
        # Log the embedding of the module actually being modified.
        log_replacement("input embedding", module.get_input_embeddings(),
                        new_module)
        module.set_input_embeddings(new_module)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped base model over a flat sequence of tokens.

        A leading batch dim of size 1 is added to the inputs and removed
        from the first model output. `attention_instances` is forwarded so
        the registered "vllm" attention function can reach the vLLM
        attention layers.

        If `inputs_embeds` is given, it is passed instead of `input_ids`.
        Most transformers models reject receiving both.
        """
        if inputs_embeds is not None:
            model_inputs = {"inputs_embeds": inputs_embeds[None, ...]}
        else:
            model_inputs = {"input_ids": input_ids[None, ...]}
        model_output = self.model(
            **model_inputs,
            use_cache=False,
            position_ids=positions[None, ...],
            intermediate_tensors=intermediate_tensors,
            attention_instances=self.attention_instances,
            return_dict=False)[0][0, ...]  # we remove batch dimension for now
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        """Sample next tokens from `logits` using vLLM's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module's parameters.

        A name not found as-is is retried with
        `self.model.base_model_prefix + "."` prepended. Weights that still
        match no parameter are skipped. Each parameter's own
        `weight_loader` is used when present, otherwise
        `default_weight_loader`.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        base_prefix = self.model.base_model_prefix
        for name, loaded_weight in weights:
            if name not in params_dict:
                name = f"{base_prefix}.{name}"
            param = params_dict.get(name)
            if param is None:
                continue
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params