# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.distributed.utils import divide
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearBase,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.platforms import current_platform

from .base import BaseLayerWithLoRA
from .utils import _get_lora_device


class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """Base class that attaches LoRA adapters to a vLLM ``LinearBase`` layer.

    The wrapped ``base_layer`` still performs the original projection; this
    class allocates stacked LoRA A/B buffers (``create_lora_weights``), loads
    adapter weights into them (``set_lora`` / ``reset_lora``) and adds the
    LoRA contribution to the base output in ``apply``.

    Notes for subclasses:
        * ``output_size`` and ``n_slices`` are only declared in ``__init__``;
          they must be assigned before ``create_lora_weights`` is called.
        * ``set_lora`` calls ``slice_lora_a`` / ``slice_lora_b`` when
          ``tp_size > 1``; they are not defined in this class.
        * ``apply`` uses ``self.punica_wrapper``, which is not assigned in
          this class -- presumably it is attached by the LoRA machinery
          before ``apply`` runs (TODO confirm).
    """

    def __init__(self, base_layer: LinearBase):
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        # Ensure tp_size and tp_rank consistency with the base_layer.
        self.tp_size = self.base_layer.tp_size
        self.tp_rank = self.base_layer.tp_rank
        self.device = _get_lora_device(self.base_layer)
        # Declared only: subclasses (and create_lora_weights, for
        # output_slices) are responsible for assigning these.
        self.output_slices: tuple[int, ...]
        self.output_size: int
        self.n_slices: int

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Allocate zero-initialised stacked LoRA A/B buffers.

        For each of the ``self.n_slices`` slices this creates a tensor in
        ``self.lora_a_stacked`` with shape
        ``(max_loras, 1, lora_a_out_size, self.input_size)`` and one in
        ``self.lora_b_stacked`` with shape
        ``(max_loras, 1, lora_b_out_size, lora_config.max_lora_rank)``, both
        of dtype ``lora_config.lora_dtype`` on ``self.device``.

        Args:
            max_loras: Number of adapter slots to allocate.
            lora_config: Supplies ``max_lora_rank``, ``lora_dtype`` and
                ``fully_sharded_loras``; it is also stored on ``self``.
            model_config: Accepted for interface compatibility; not used here.

        Raises:
            NotImplementedError: If ``base_layer`` is not a
                ``ReplicatedLinear``, ``ColumnParallelLinear`` or
                ``RowParallelLinear``.
        """
        self.lora_config = lora_config
        # Choose per-rank buffer sizes according to the base layer's
        # tensor-parallel layout. Branch order matters for isinstance checks.
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, ColumnParallelLinear):
            # With fully sharded LoRAs the rank dimension of A is divided
            # across TP ranks; B's output size is not divided here.
            lora_a_out_size = (
                lora_config.max_lora_rank
                if not lora_config.fully_sharded_loras
                else divide(lora_config.max_lora_rank, self.tp_size)
            )
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, RowParallelLinear):
            # With fully sharded LoRAs the output dimension of B is divided
            # across TP ranks; A keeps the full rank.
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (
                self.output_size
                if not lora_config.fully_sharded_loras
                else divide(self.output_size, self.tp_size)
            )

        else:
            raise NotImplementedError(
                "LoRA is not supported for base layer type "
                f"{type(self.base_layer).__name__}"
            )

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        # Note: B's last dim is always the full max_lora_rank, even when A's
        # rank dim was divided above.
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        # Single-slice layout: one output slice spanning B's output dim.
        self.output_slices = (self.lora_b_stacked[0].shape[2],)

    def reset_lora(self, index: int):
        """Zero adapter slot ``index`` in every A and B slice buffer."""
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Load one adapter's A/B weights into slot ``index``.

        The slot is zeroed first. When ``tp_size > 1`` the weights are
        passed through ``slice_lora_a`` / ``slice_lora_b``. They are then
        copied into the top-left corner of the preallocated buffers, so
        weights smaller than the buffer (e.g. a rank below
        ``max_lora_rank``) leave the remainder zero.
        """
        # Except for QKVParallelLinearWithLoRA and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        assert (
            len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
        )

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
            lora_a, non_blocking=True
        )
        self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
            lora_b, non_blocking=True
        )

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        """Run the base layer's ``quant_method.apply`` and add the LoRA delta.

        Args:
            x: Input activations.
            bias: Optional bias forwarded to the base layer's quant method.

        Returns:
            The base output plus the LoRA contribution. If the base output
            was 3-D, it is reshaped back to that original shape.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        original_shape = output.shape if output.ndim == 3 else None

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
        )
        # On platforms that cannot update `output` in place, use the result
        # returned by the punica wrapper instead.
        if not current_platform.can_update_inplace():
            output = lora_output

        # Reshape the flattened output back to its original shape,
        # as some MM encoders cannot handle flattened inputs.
        if original_shape is not None:
            output = output.reshape(original_shape)

        return output

    @property
    def weight(self) -> torch.Tensor:
        """Return the base layer's weight tensor, whatever its storage name.

        Raises:
            ValueError: If none of the known weight attributes is present.
        """
        # unquantizedLinear
        if hasattr(self.base_layer, "weight"):
            return self.base_layer.weight
        # Compressed Tensor
        elif hasattr(self.base_layer, "weight_packed"):
            return self.base_layer.weight_packed
        # GPTQ/AWQ
        elif hasattr(self.base_layer, "qweight"):
            return self.base_layer.qweight
        # marlin
        elif hasattr(self.base_layer, "B"):
            return self.base_layer.B
        else:
            raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> torch.Tensor | None:
        """Return the base layer's ``bias`` attribute, or ``None`` if absent."""
        return getattr(self.base_layer, "bias", None)