# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
18
from collections.abc import Iterable
from contextlib import nullcontext
from typing import Literal, Optional, Union

import regex as re
import torch
from torch import nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, VllmConfig)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, maybe_prefix)
50
51
52
53
54
55
56
57
58
59
60
61

# Module-level logger, created through vLLM's `init_logger` with this module's
# name; used by the helpers and model classes in this file.
logger = init_logger(__name__)


def vllm_flash_attention_forward(
        # Transformers args
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor,
        # Transformers kwargs
        scaling: Optional[float] = None,
        # vLLM kwargs
        attention_instances: Optional[dict[int, Attention]] = None,
        **kwargs):
    """
    Attention function that hands query/key/value to a vLLM `Attention` layer.

    Args:
        module: The calling attention module. Its `layer_idx` attribute is
            used as the key into `attention_instances`.
        query: Query tensor.
        key: Key tensor.
        value: Value tensor.
        attention_mask: Accepted to match the caller's interface; not read.
        scaling: When not None, written to `impl.scale` of the selected
            `Attention` instance (as a float) before attention is run.
        attention_instances: Mapping from layer index to the `Attention`
            instance for that layer. Must be provided; the `None` default
            exists only so the parameter can be passed by keyword.
        **kwargs: Any further keyword arguments are ignored.
    Returns:
        tuple: `(output, None)` where `output` is whatever the selected
        `Attention` instance's `forward` returns.
    """
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        self_attn.impl.scale = float(scaling)
    # NOTE(review): presumably the inputs are laid out as
    # (batch=1, num_heads, num_tokens, head_dim) - confirm against the
    # Transformers attention interface. Under that assumption the second to
    # last dimension is the token count and the reshape below yields
    # (num_tokens, num_heads * head_dim).
    num_tokens = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(num_tokens, -1)
                         for x in (query, key, value))
    return self_attn.forward(query, key, value), None
73
74
75
76
77


# Register the function above under the "vllm" key of the mapping imported
# from `transformers.modeling_utils`; `TransformersModel` builds its model
# with `attn_implementation="vllm"`.
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


78
79
80
81
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Write a debug record describing the swap of `old_module` for
    `new_module` at `name`."""
    swap = (name, old_module, new_module)
    logger.debug("%s: %s -> %s", *swap)


82
def replace_linear_class(
    linear: nn.Linear, style: Literal["colwise", "rowwise"],
    quant_config: QuantizationConfig
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Args:
        linear (nn.Linear): `nn.Linear` to be replaced. Only its
            `in_features`, `out_features` and whether it has a bias are read.
        style (str): Tensor parallel style of the new linear, e.g. "colwise".
            Any string other than "colwise" or "rowwise" selects
            `ReplicatedLinear`.
        quant_config (QuantizationConfig): Quantization config for the new
            linear.
    Returns:
        Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
            The new linear, constructed with `return_bias=False`.
    Raises:
        ValueError: If `style` is not a string.
    """

    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str")

    vllm_linear_cls = {
        "colwise": ColumnParallelLinear,
        "rowwise": RowParallelLinear,
    }.get(style, ReplicatedLinear)

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        return_bias=False,
    )

114

115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
class ConfigOverride:
    """
    Context manager that sets attributes on a config for the duration of a
    `with` block and puts the previous state back on exit.

    Attributes that did not exist before entering are removed again on exit;
    attributes that did exist get their earlier value back.
    """

    def __init__(self, config: PretrainedConfig, **kwargs):
        self.config = config
        # Attribute name -> value to apply while the context is active.
        self.kwargs = kwargs
        # Attribute name -> value seen on entry (None if it was absent).
        self.kwargs_original = {}
        # Names that were absent on entry and must be deleted on exit.
        self.kwargs_delete = set()

    def __enter__(self):
        """Apply every override and return the (mutated) config."""
        cfg = self.config
        for attr_name, override in self.kwargs.items():
            was_present = hasattr(cfg, attr_name)
            if not was_present:
                self.kwargs_delete.add(attr_name)
            self.kwargs_original[attr_name] = getattr(cfg, attr_name, None)
            setattr(cfg, attr_name, override)
        return cfg

    def __exit__(self, exc_type, exc_value, traceback):
        """Undo every override; exceptions from the body are not suppressed."""
        cfg = self.config
        for attr_name, previous in self.kwargs_original.items():
            if attr_name not in self.kwargs_delete:
                setattr(cfg, attr_name, previous)
                continue
            delattr(cfg, attr_name)

142
class TransformersModel(nn.Module):
    """
    Base model created through `AutoModel.from_config` and adapted for vLLM.

    The model is instantiated on the `meta` device with
    `attn_implementation="vllm"`. It is then modified in place: the pipeline
    and tensor parallel plans are applied, the input embedding is replaced by
    a `VocabParallelEmbedding`, one `Attention` instance is created per local
    layer, and any buffers/parameters still on `meta` are materialized.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        Args:
            vllm_config: Provides the HF config together with the cache,
                device, model, parallel and quantization configs, all of
                which are stored on this module.
            prefix: Accepted for interface parity; not read by this
                constructor.
        """
        super().__init__()
        logger.info("Using Transformers backend.")

        config: PretrainedConfig = vllm_config.model_config.hf_config
        cache_config: CacheConfig = vllm_config.cache_config
        device_config: DeviceConfig = vllm_config.device_config
        model_config: ModelConfig = vllm_config.model_config
        parallel_config: ParallelConfig = vllm_config.parallel_config
        quant_config: QuantizationConfig = vllm_config.quant_config

        self.config = config
        self.cache_config = cache_config
        self.device_config = device_config
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.quant_config = quant_config

        self.pp_group = get_pp_group()
        self.pp_size = self.pp_group.world_size
        self.pp_rank = self.pp_group.rank_in_group
        self.tp_size = get_tensor_model_parallel_world_size()

        # vLLM handles interleaved sliding window attention by creating a new
        # interleaved_sliding_window attribute and deleting the sliding_window
        # attribute. This breaks the constructors in Transformers so we
        # temporarily add the attribute back to construct the model.
        config_override = nullcontext()
        if hasattr(config, "interleaved_sliding_window"):
            config_override = ConfigOverride(
                config, sliding_window=config.interleaved_sliding_window)

        # Use meta device to delay allocating GPU tensors
        with torch.device("meta"), config_override:
            # FIXME(Isotr0py): We need to refactor this part in the future to
            # avoid registering an extra model layer, otherwise we will need a
            # weights mapper to rename weights.
            self.model: PreTrainedModel = AutoModel.from_config(
                config,
                attn_implementation="vllm",
                torch_dtype=model_config.dtype,
                trust_remote_code=model_config.trust_remote_code,
            )

        self.pipeline_parallel()
        self.tensor_parallel()

        # Input embeddings
        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    config.vocab_size,
                    config.hidden_size,
                    org_num_embeddings=config.vocab_size,
                    quant_config=quant_config,
                ))

        # Attention layers
        self.attention_instances = self.create_attention_instances()

        # Initialize buffers (e.g. rotary embedding inverse frequency)
        self.init_buffers(self.model)

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.

        Does nothing when the pipeline parallel size is 1 or less. Otherwise
        the entries of the model's `_pp_plan` are split around its single
        `nn.ModuleList`, and every module that does not belong on this rank
        is replaced by a `PPMissingLayer`.

        Raises:
            ValueError: If the model does not support a pipeline parallel
                plan, if the plan names more than one `nn.ModuleList`, or if
                it names none.
        """
        if self.pp_size <= 1:
            return

        if not self.model.supports_pp_plan:
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel yet!")

        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            if isinstance(getattr(self.model, name), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!")
        if module_list_idx is None:
            raise ValueError(
                f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list: kept on the first rank, and also on the
        # last rank when the word embeddings are tied.
        for name in pp_plan[:module_list_idx]:
            if self.pp_group.is_first_rank or (self.config.tie_word_embeddings
                                               and self.pp_group.is_last_rank):
                continue
            setattr(self.model, name, PPMissingLayer())

        # Module list: keep only the layers in [start_layer, end_layer)
        start_layer, end_layer = get_pp_indices(self.config.num_hidden_layers,
                                                self.pp_rank, self.pp_size)
        layers_name = pp_plan[module_list_idx]
        layers = getattr(self.model, layers_name)
        for i in range(len(layers)):
            if start_layer <= i < end_layer:
                continue
            layers[i] = PPMissingLayer(return_tuple=True)

        # Layers after module list
        for name in pp_plan[module_list_idx + 1:]:
            # Modules that should be on last rank
            if not self.pp_group.is_last_rank:
                setattr(self.model, name, PPMissingLayer())

    def tensor_parallel(self):
        """
        Apply the model's tensor parallelization plan.
        Currently only supports linear layers.

        Each `nn.Linear` whose qualified name matches a pattern of the
        model's `_tp_plan` is replaced via `replace_linear_class` with the
        style of the first matching pattern.

        Raises:
            ValueError: If the model does not support a tensor parallel plan
                and the tensor parallel size is greater than 1.
        """
        if not self.model.supports_tp_plan:
            if self.tp_size <= 1:
                return

            raise ValueError(
                f"{type(self.model)} does not support tensor parallel yet!")

        tp_plan = self.model._tp_plan

        def _tensor_parallel(module: nn.Module, prefix: str = ""):
            for child_name, child_module in module.named_children():
                qual_name = maybe_prefix(prefix, child_name)
                for pattern, style in tp_plan.items():
                    if re.match(pattern, qual_name) and isinstance(
                            child_module, nn.Linear):
                        new_module = replace_linear_class(
                            child_module, style, self.quant_config)
                        setattr(module, child_name, new_module)
                        log_replacement(qual_name, child_module, new_module)
                        # Stop at the first match. Without this `break` the
                        # `else` below would always run and a layer matching
                        # several patterns would be rebuilt once per match.
                        break
                else:
                    # No replacement happened: descend into the child.
                    _tensor_parallel(child_module, prefix=qual_name)

        _tensor_parallel(self.model)

    def create_attention_instances(self) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.

        Returns:
            dict[int, Attention]: One instance per layer index in the range
            that `get_pp_indices` assigns to this pipeline rank.
        """
        num_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        start, end = get_pp_indices(self.config.num_hidden_layers,
                                    self.pp_rank, self.pp_size)

        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention: only layers whose
            # 1-based index is not a multiple of `sliding_window_pattern` get
            # a sliding window.
            sliding_window = None
            if (hasattr(self.config, "interleaved_sliding_window")
                    and hasattr(self.config, "sliding_window_pattern")
                    and ((i + 1) % self.config.sliding_window_pattern > 0)):
                sliding_window = self.config.interleaved_sliding_window

            attention_instances[i] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                per_layer_sliding_window=sliding_window,
                prefix=f"{i}.attn")
        return attention_instances

    def init_buffers(self, module: nn.Module):
        """
        Recursively replace buffers that are still on the `meta` device.

        If a `buffer` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with torch.device("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```

        This means that:
        - `type(module)` is a class from `transformers`
        - This class is constructed using a `PretrainedConfig`

        A fresh instance of `type(module)` is therefore built from
        `self.config` and the buffer of the same name is taken from it.
        """
        for name, buffer in module.named_buffers(recurse=False):
            if buffer.device == torch.device("meta"):
                # Identity check: the warning is about the base model itself.
                if module is self.model:
                    logger.warning(
                        "To initialize buffers correctly, we instantiate the "
                        "parent module and and extract the value of the "
                        "buffer from it. In this case, the parent module is "
                        "the base model. Instantiating the entire model here "
                        "risks GPU OOM. Could this buffer be moved to a child "
                        "module?")
                new_buffer = getattr(type(module)(self.config), name)
                setattr(module, name, new_buffer)
        for child in module.children():
            self.init_buffers(child)

    def init_parameters(self, module: nn.Module):
        """
        Recursively allocate parameters that are still on the `meta` device.

        If a `parameter` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with torch.device("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```

        Each such parameter is replaced by an uninitialized tensor of the
        same shape and dtype on `self.device_config.device`.
        """
        for name, param in module.named_parameters(recurse=False):
            if param.device == torch.device("meta"):
                new_param = nn.Parameter(
                    torch.empty_like(param.data,
                                     device=self.device_config.device))
                setattr(module, name, new_param)
        for child in module.children():
            self.init_parameters(child)

    def get_input_embeddings(self) -> nn.Module:
        """Return the input embedding module of the wrapped model."""
        return self.model.get_input_embeddings()

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """
        Run the wrapped model on this pipeline rank.

        Args:
            input_ids: Token ids; ignored on ranks other than the first.
            positions: Passed to the wrapped model as `position_ids`.
            intermediate_tensors: Required on ranks other than the first,
                where its "hidden_states" entry is used as `inputs_embeds`.
            inputs_embeds: Optional precomputed input embeddings.
        Returns:
            The hidden states, wrapped in `IntermediateTensors` under the
            "hidden_states" key unless this is the last pipeline rank.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        # The wrapped model is called with a leading dimension of size 1
        # added to each input.
        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=positions[None, ...],
            attention_instances=self.attention_instances,
            return_dict=False)[0][0, ...]  # we remove batch dimension for now

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """
        Load `(name, tensor)` pairs into this module's parameters.

        Names are prefixed with "model." when they do not already start with
        it. Weights for pipeline-missing parameters and names that match no
        parameter are skipped.

        Returns:
            set[str]: The (prefixed) names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())

        loaded_params = set[str]()
        for name, loaded_weight in weights:
            # Use "model" instead of base_model_prefix because
            # the base model attribute in vLLM is always `model`
            if not name.startswith(prefix := "model."):
                name = prefix + name

            if is_pp_missing_parameter(name, self):
                continue
            if name in params_dict:
                param = params_dict[name]
                # Parameters may carry their own loader; fall back to the
                # default one otherwise.
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        return loaded_params
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466


@support_torch_compile
class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
                              SupportsPP):
    """
    Causal LM built from a `TransformersModel` plus a language-model head.

    The head and the logits processor exist only on the last pipeline rank;
    every other rank holds a `PPMissingLayer` in place of the head.
    """
    embedding_padding_modules = ["lm_head"]
    # TODO transformers will have a util to get it
    embedding_modules = ["embed_tokens"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config: PretrainedConfig = vllm_config.model_config.hf_config
        head_quant_config: QuantizationConfig = vllm_config.quant_config

        self.config = hf_config

        self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        else:
            vocab_size = hf_config.vocab_size
            self.unpadded_vocab_size = vocab_size

            head = ParallelLMHead(
                vocab_size,
                hf_config.hidden_size,
                quant_config=head_quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            # With tied embeddings the head is tied to the input embedding.
            if hf_config.tie_word_embeddings:
                head = head.tie_weights(self.model.get_input_embeddings())
            self.lm_head = head

            # `logit_scale` falls back to 1.0 when the config has none.
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size,
                vocab_size,
                getattr(hf_config, "logit_scale", 1.0),
            )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
    # this makes thing complicated. We need to remove this mapper after refactor
    # `TransformersModel` in the future.
    @property
    def hf_to_vllm_mapper(self):
        """Build the `WeightsMapper` that renames checkpoint weights to
        account for the extra `model.` nesting level."""
        child_prefixes = {}
        for child_name, _ in self.model.model.named_children():
            child_prefixes[child_name] = "model." + child_name
        return WeightsMapper(
            orig_to_new_substr={"model.": "model.model."},
            orig_to_new_prefix=child_prefixes,
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward all arguments to the wrapped `TransformersModel` and
        return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor over `hidden_states` using the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights through `AutoWeightsLoader`, skipping `lm_head.`
        entries when the word embeddings are tied."""
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        auto_loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return auto_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)