"vscode:/vscode.git/clone" did not exist on "80bb1e8afe950342e93b7262e7bf25eb6d29b287"
# SPDX-License-Identifier: Apache-2.0

# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
import re
18
from typing import Iterable, Literal, Optional, Union
19
20
21

import torch
from torch import nn
22
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
23
24
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

25
from vllm.attention import Attention
26
from vllm.compilation.decorators import support_torch_compile
27
28
29
30
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, VllmConfig)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
31
32
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
33
                                               ReplicatedLinear,
34
35
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
from vllm.model_executor.layers.quantization import QuantizationConfig
37
38
39
40
41
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

43
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
44
45
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                    is_pp_missing_parameter,
46
                    make_empty_intermediate_tensors_factory, maybe_prefix)
47
48
49
50
51
52
53
54
55
56
57
58

# Module-level logger used by `log_replacement` and `TransformersModel`.
logger = init_logger(__name__)


def vllm_flash_attention_forward(
        # Transformers args
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor,
        # Transformers kwargs
        scaling: Optional[float] = None,
        # vLLM kwargs
        attention_instances: Optional[dict[int, Attention]] = None,
        **kwargs):
    """
    Attention function that routes a Transformers attention call to vLLM.

    Args:
        module (torch.nn.Module): Calling attention module; its `layer_idx`
            attribute selects the vLLM `Attention` instance to use.
        query (torch.Tensor): Query tensor.
        key (torch.Tensor): Key tensor.
        value (torch.Tensor): Value tensor.
        attention_mask (torch.Tensor): Accepted but not used here.
        scaling (Optional[float]): If given, overrides the scale of the
            selected attention implementation.
        attention_instances (Optional[dict[int, Attention]]): Mapping from
            layer index to vLLM `Attention` instance. Must be provided.
        **kwargs: Extra keyword arguments; ignored.
    Returns:
        tuple: `(attention_output, None)`.
    Raises:
        ValueError: If `attention_instances` is not provided.
    """
    if attention_instances is None:
        raise ValueError(
            "`attention_instances` must be provided to use the vLLM "
            "attention implementation")

    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        self_attn.impl.scale = float(scaling)

    # NOTE(review): presumably the inputs are laid out as
    # (batch, num_heads, num_tokens, head_dim) with batch == 1, in which case
    # `shape[-2]` is the number of tokens - confirm against the caller.
    num_rows = query.shape[-2]
    # Swap dims 1 and 2, then flatten everything after the first dimension
    # so that each tensor becomes 2D with `num_rows` rows.
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(num_rows, -1)
                         for x in (query, key, value))
    return self_attn.forward(query, key, value), None
70
71
72
73
74


# Register the vLLM attention function under the key "vllm".
# `TransformersModel` builds its model with `attn_implementation="vllm"`,
# which presumably selects this entry - confirm against Transformers.
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


75
76
77
78
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """
    Log, at debug level, that a module has been swapped for another.

    Args:
        name (str): Name of the module being replaced.
        old_module (nn.Module): Module that was replaced.
        new_module (nn.Module): Module that took its place.
    """
    message_format = "%s: %s -> %s"
    logger.debug(message_format, name, old_module, new_module)


79
def replace_linear_class(
    linear: nn.Linear, style: Literal["colwise", "rowwise"],
    quant_config: QuantizationConfig
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Args:
        linear (nn.Linear): `nn.Linear` to be replaced.
        style (str): Tensor parallel style of the new linear, e.g. "colwise".
            Any string other than "colwise" or "rowwise" selects
            `ReplicatedLinear`.
        quant_config (QuantConfig): Quantization config for the new linear.
    Returns:
        Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
            The new linear.
    Raises:
        ValueError: If `style` is not a string.
    """

    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str")

    # Unrecognised style strings fall back to a replicated linear layer.
    vllm_linear_cls = {
        "colwise": ColumnParallelLinear,
        "rowwise": RowParallelLinear,
    }.get(style, ReplicatedLinear)

    # The new layer mirrors the shape and bias setting of the original.
    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        return_bias=False,
    )

111

112
class TransformersModel(nn.Module):
    """
    Base model built from a `transformers` config and adapted for vLLM.

    The wrapped model is created on the `meta` device with
    `attn_implementation="vllm"`. Its modules are then swapped for
    pipeline-parallel placeholders and vLLM linear/embedding layers, and any
    buffers or parameters still on the `meta` device are materialized.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        Build the wrapped model and apply the parallelization plans.

        Args:
            vllm_config (VllmConfig): Source of the model, cache, device,
                parallel and quantization configs.
            prefix (str): Currently unused by this constructor.
        """
        super().__init__()
        logger.info("Using Transformers backend.")

        config: PretrainedConfig = vllm_config.model_config.hf_config
        cache_config: CacheConfig = vllm_config.cache_config
        device_config: DeviceConfig = vllm_config.device_config
        model_config: ModelConfig = vllm_config.model_config
        parallel_config: ParallelConfig = vllm_config.parallel_config
        quant_config: QuantizationConfig = vllm_config.quant_config

        self.config = config
        self.cache_config = cache_config
        self.device_config = device_config
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.quant_config = quant_config

        self.pp_group = get_pp_group()
        self.pp_size = self.pp_group.world_size
        self.pp_rank = self.pp_group.rank_in_group
        self.tp_size = get_tensor_model_parallel_world_size()

        # Use meta device to delay allocating GPU tensors
        with torch.device("meta"):
            # FIXME(Isotr0py): We need to refactor this part in the future to
            # avoid registering an extra model layer, otherwise we will need a
            # weights mapper to rename weights.
            self.model: PreTrainedModel = AutoModel.from_config(
                config,
                attn_implementation="vllm",
                torch_dtype=model_config.dtype,
                trust_remote_code=model_config.trust_remote_code,
            )

        self.pipeline_parallel()
        self.tensor_parallel()

        # Input embeddings
        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    config.vocab_size,
                    config.hidden_size,
                    org_num_embeddings=config.vocab_size,
                    quant_config=quant_config,
                ))

        # Attention layers
        self.attention_instances = self.create_attention_instances()

        # Initialize buffers (e.g. rotary embedding inverse frequency)
        self.init_buffers(self.model)

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.

        Modules that do not belong to this pipeline rank are replaced with
        `PPMissingLayer`. Does nothing when the pipeline parallel size is 1.

        Raises:
            ValueError: If the model does not support a pipeline parallel
                plan, or if the plan does not contain exactly one
                `nn.ModuleList`.
        """
        if self.pp_size <= 1:
            return

        if not self.model.supports_pp_plan:
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel yet!")

        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            if isinstance(getattr(self.model, name), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!")
        if module_list_idx is None:
            raise ValueError(
                f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list
        for name in pp_plan[:module_list_idx]:
            # Kept on the first rank, and also on the last rank when the word
            # embeddings are tied.
            if self.pp_group.is_first_rank or (self.config.tie_word_embeddings
                                               and self.pp_group.is_last_rank):
                continue
            setattr(self.model, name, PPMissingLayer())

        # Module list
        start_layer, end_layer = get_pp_indices(self.config.num_hidden_layers,
                                                self.pp_rank, self.pp_size)
        layers_name = pp_plan[module_list_idx]
        layers = getattr(self.model, layers_name)
        for i in range(len(layers)):
            # Layers in [start_layer, end_layer) belong to this rank
            if start_layer <= i < end_layer:
                continue
            layers[i] = PPMissingLayer(return_tuple=True)

        # Layers after module list
        for name in pp_plan[module_list_idx + 1:]:
            # Modules that should be on last rank
            if not self.pp_group.is_last_rank:
                setattr(self.model, name, PPMissingLayer())

    def tensor_parallel(self):
        """
        Apply the model's tensor parallelization plan.
        Currently only supports linear layers.

        Raises:
            ValueError: If the model does not support a tensor parallel plan
                and the tensor parallel size is greater than 1.
        """
        if not self.model.supports_tp_plan:
            if self.tp_size <= 1:
                return

            raise ValueError(
                f"{type(self.model)} does not support tensor parallel yet!")

        tp_plan = self.model._tp_plan

        def _tensor_parallel(module: nn.Module, prefix: str = ""):
            # Walk the module tree, replacing each `nn.Linear` whose
            # qualified name matches a pattern in the plan.
            for child_name, child_module in module.named_children():
                qual_name = maybe_prefix(prefix, child_name)
                for pattern, style in tp_plan.items():
                    if re.match(pattern, qual_name) and isinstance(
                            child_module, nn.Linear):
                        new_module = replace_linear_class(
                            child_module, style, self.quant_config)
                        setattr(module, child_name, new_module)
                        log_replacement(qual_name, child_module, new_module)
                        # First matching pattern wins. Without this `break`
                        # the `else` below would always run and later
                        # matching patterns would overwrite the replacement.
                        break
                else:
                    # Nothing was replaced, so descend into the child
                    _tensor_parallel(child_module, prefix=qual_name)

        _tensor_parallel(self.model)

    def create_attention_instances(self) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.

        Returns:
            dict[int, Attention]: One `Attention` per layer index in the
                range assigned to this pipeline rank.
        """
        num_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        start, end = get_pp_indices(self.config.num_hidden_layers,
                                    self.pp_rank, self.pp_size)
        return {
            i:
            Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                prefix=f"{i}.attn")
            for i in range(start, end)
        }

    def init_buffers(self, module: nn.Module):
        """
        Recursively materialize buffers that are still on the `meta` device.

        If a `buffer` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with torch.device("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```

        This means that:
        - `type(module)` is a class from `transformers`
        - This class is constructed using a `PretrainedConfig`

        Args:
            module (nn.Module): Root of the module tree to process.
        """
        for name, buffer in module.named_buffers(recurse=False):
            if buffer.device == torch.device("meta"):
                # Identity check: we want to know whether this is the base
                # model object itself.
                if module is self.model:
                    logger.warning(
                        "To initialize buffers correctly, we instantiate the "
                        "parent module and extract the value of the "
                        "buffer from it. In this case, the parent module is "
                        "the base model. Instantiating the entire model here "
                        "risks GPU OOM. Could this buffer be moved to a child "
                        "module?")
                # Build a fresh instance of the parent module from the config
                # and take the real buffer value from it.
                new_buffer = getattr(type(module)(self.config), name)
                setattr(module, name, new_buffer)
        for child in module.children():
            self.init_buffers(child)

    def init_parameters(self, module: nn.Module):
        """
        Recursively allocate parameters that are still on the `meta` device.

        If a `parameter` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with torch.device("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```

        Each such parameter is replaced by an uninitialized parameter of the
        same shape and dtype on the configured device.

        Args:
            module (nn.Module): Root of the module tree to process.
        """
        for name, param in module.named_parameters(recurse=False):
            if param.device == torch.device("meta"):
                new_param = nn.Parameter(
                    torch.empty_like(param.data,
                                     device=self.device_config.device))
                setattr(module, name, new_param)
        for child in module.children():
            self.init_parameters(child)

    def get_input_embeddings(self) -> nn.Module:
        """Return the input embedding module of the wrapped model."""
        return self.model.get_input_embeddings()

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """
        Run the wrapped model on this pipeline rank.

        Args:
            input_ids (Optional[torch.Tensor]): Token ids; ignored on ranks
                other than the first.
            positions (torch.Tensor): Passed to the model as `position_ids`.
            intermediate_tensors (Optional[IntermediateTensors]): Required
                on ranks other than the first; its "hidden_states" entry is
                used as the input embeddings.
            inputs_embeds (Optional[torch.Tensor]): Input embeddings.
        Returns:
            Union[torch.Tensor, IntermediateTensors]: Hidden states on the
                last rank, otherwise `IntermediateTensors` wrapping them.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        # The wrapped model is called with a leading batch dimension of 1
        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=positions[None, ...],
            attention_instances=self.attention_instances,
            return_dict=False)[0][0, ...]  # we remove batch dimension for now

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """
        Load weights into the parameters of this module.

        Args:
            weights (Iterable[tuple[str, torch.Tensor]]): Pairs of weight
                name and tensor.
        Returns:
            set[str]: Names (with the "model." prefix) of the parameters
                that were loaded.
        """
        params_dict = dict(self.named_parameters())

        loaded_params = set[str]()
        for name, loaded_weight in weights:
            # Use "model" instead of base_model_prefix because
            # the base model attribute in vLLM is always `model`
            if not name.startswith(prefix := "model."):
                name = prefix + name

            if is_pp_missing_parameter(name, self):
                continue
            # Weights with no matching parameter are skipped
            if name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        return loaded_params
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419


@support_torch_compile
class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
                              SupportsPP):
    """
    Causal language model that pairs `TransformersModel` with an LM head.

    The LM head and logits processor only exist on the last pipeline rank;
    other ranks hold a `PPMissingLayer` in place of the head.
    """
    embedding_padding_modules = ["lm_head"]
    # TODO transformers will have a util to get it
    embedding_modules = ["embed_tokens"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        Build the base model and, on the last pipeline rank, the LM head.

        Args:
            vllm_config (VllmConfig): Source of the model and quantization
                configs; also forwarded to `TransformersModel`.
            prefix (str): Forwarded to `TransformersModel` and used to
                prefix the LM head name.
        """
        super().__init__()
        hf_config: PretrainedConfig = vllm_config.model_config.hf_config
        quantization: QuantizationConfig = vllm_config.quant_config

        self.config = hf_config

        self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        else:
            vocab_size = hf_config.vocab_size
            self.unpadded_vocab_size = vocab_size
            self.lm_head = ParallelLMHead(
                vocab_size,
                hf_config.hidden_size,
                quant_config=quantization,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if hf_config.tie_word_embeddings:
                input_embeddings = self.model.get_input_embeddings()
                self.lm_head = self.lm_head.tie_weights(input_embeddings)

            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size,
                vocab_size,
                getattr(hf_config, "logit_scale", 1.0),
            )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
    # this makes things complicated. We need to remove this mapper after
    # refactoring `TransformersModel` in the future.
    @property
    def hf_to_vllm_mapper(self):
        """
        Build the mapper that renames checkpoint weights for this wrapper.

        Every direct child name of the wrapped base model is mapped to the
        same name prefixed with "model.", and the substring "model." is
        mapped to "model.model.".
        """
        child_prefixes = {}
        for child_name, _ in self.model.model.named_children():
            child_prefixes[child_name] = "model." + child_name
        return WeightsMapper(
            orig_to_new_substr={"model.": "model.model."},
            orig_to_new_prefix=child_prefixes,
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the base model and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor over `hidden_states` with the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """
        Load weights through `AutoWeightsLoader` using `hf_to_vllm_mapper`.

        Weights prefixed with "lm_head." are skipped when the config ties
        the word embeddings.

        Returns:
            set[str]: Whatever the loader's `load_weights` returns.
        """
        if self.config.tie_word_embeddings:
            skipped_prefixes = ["lm_head."]
        else:
            skipped_prefixes = None
        weights_loader = AutoWeightsLoader(
            self,
            skip_prefixes=skipped_prefixes,
        )
        return weights_loader.load_weights(weights,
                                           mapper=self.hf_to_vllm_mapper)