# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch
from transformers import PretrainedConfig

from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.config.lora import LoRAConfig
from vllm.distributed.utils import divide
from vllm.forward_context import (
    ForwardContext,
    get_forward_context,
    is_forward_context_available,
)
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearBase,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.platforms import current_platform
from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
from vllm.utils.torch_utils import direct_register_custom_op

from .base import BaseLayerWithLoRA
from .utils import _get_lora_device


if envs.VLLM_LORA_ENABLE_DUAL_STREAM:
    # Module-wide auxiliary stream shared by every LoRA linear layer; it is
    # created lazily by _get_lora_aux_cuda_stream() on first use.
    _lora_aux_cuda_stream: torch.cuda.Stream | None = None

    def _get_lora_aux_cuda_stream() -> torch.cuda.Stream | None:
        """Return the shared auxiliary stream, creating it on first use.

        Returns:
            The cached ``torch.cuda.Stream``. On a platform for which
            ``current_platform.is_cuda_alike()`` is false, no stream is
            created and ``None`` is returned.
        """
        global _lora_aux_cuda_stream
        # Fast path: the stream already exists, so skip the platform query.
        if _lora_aux_cuda_stream is not None:
            return _lora_aux_cuda_stream
        if current_platform.is_cuda_alike():
            _lora_aux_cuda_stream = torch.cuda.Stream()
        return _lora_aux_cuda_stream

    def lora_linear_async(
        layer_name: str,
        output_size: int,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the dual-stream LoRA linear forward for the named layer.

        The layer object is looked up by ``layer_name`` in the current
        forward context's ``no_compile_layers`` mapping and its
        ``_apply_async_impl`` is invoked with ``x`` and ``bias``.

        Args:
            layer_name: Key of the target layer in ``no_compile_layers``.
            output_size: Not used here; it is part of the signature because
                the fake implementation needs it to build the output shape.
            x: Input activations.
            bias: Optional bias forwarded to the layer.

        Returns:
            Whatever the layer's ``_apply_async_impl`` returns.
        """
        lora_layer = get_forward_context().no_compile_layers[layer_name]
        return lora_layer._apply_async_impl(x, bias)

    def lora_linear_async_fake(
        layer_name: str,
        output_size: int,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # The real function reshapes output back to the original 3D shape
        # when the input has an extra batch dimension (transformers backend).
        if x.ndim == 3:
            return torch.empty(
                (x.size(0), x.size(1), output_size),
                device=x.device,
                dtype=x.dtype,
            )
        return torch.empty(
            (x.size(0), output_size),
            device=x.device,
            dtype=x.dtype,
        )

    # Register the real and fake implementations as a custom op named
    # "lora_linear_async"; BaseLinearLayerWithLoRA.apply() invokes it as
    # torch.ops.vllm.lora_linear_async when dual-stream mode is enabled.
    direct_register_custom_op(
        op_name="lora_linear_async",
        op_func=lora_linear_async,
        fake_impl=lora_linear_async_fake,
    )


class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper shared by the linear-layer LoRA implementations.

    The wrapper holds a base linear layer, owns the stacked LoRA A/B weight
    buffers, and adds the LoRA contribution to the base layer's output via
    ``self.punica_wrapper`` (not defined in this class; presumably provided
    by ``BaseLayerWithLoRA``).

    ``output_size``, ``n_slices`` and ``output_slices`` are only declared in
    ``__init__``; ``output_size`` and ``n_slices`` must be assigned before
    ``create_lora_weights`` is called.
    """

    def __init__(self, base_layer: LinearBase):
        """Wrap ``base_layer`` and mirror its size / tensor-parallel info.

        Args:
            base_layer: The linear layer whose output LoRA is applied to.
        """
        super().__init__()

        self._enable_aux_cuda_stream = envs.VLLM_LORA_ENABLE_DUAL_STREAM
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        # Ensure tp_size and tp_rank consistency with the base_layer.
        self.tp_size = self.base_layer.tp_size
        self.tp_rank = self.base_layer.tp_rank
        self.device = _get_lora_device(self.base_layer)
        self._init_lora_stream_context()
        # Declared here for type checkers only; no value is assigned yet.
        self.output_slices: tuple[int, ...]
        self.output_size: int
        self.n_slices: int

    def _init_lora_stream_context(self) -> None:
        """Prepare the state needed by the dual-stream forward path.

        Does nothing unless ``VLLM_LORA_ENABLE_DUAL_STREAM`` was set when the
        layer was constructed. Otherwise it fetches the shared auxiliary
        stream, creates two CUDA events, and registers this layer in the
        compilation config's ``static_forward_context`` under a unique name.

        Raises:
            AssertionError: If the current platform is not CUDA-alike.
            ValueError: If the derived layer name is already registered.
        """
        if not self._enable_aux_cuda_stream:
            return
        vllm_config = get_current_vllm_config()
        self._lora_stream = _get_lora_aux_cuda_stream()
        assert current_platform.is_cuda_alike()
        self._events = [torch.cuda.Event(), torch.cuda.Event()]
        # The ".lora_linear_async" suffix avoids prefix conflicts with the
        # base layer.
        self.layer_name = self.base_layer.prefix + ".lora_linear_async"
        compilation_config = vllm_config.compilation_config
        if self.layer_name in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {self.layer_name}")
        compilation_config.static_forward_context[self.layer_name] = self

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Allocate zero-initialized stacked LoRA A/B buffers.

        One A buffer of shape ``(max_loras, 1, lora_a_out_size, input_size)``
        and one B buffer of shape
        ``(max_loras, 1, lora_b_out_size, max_lora_rank)`` is created per
        slice. With ``fully_sharded_loras`` the A output size (column
        parallel) or the B output size (row parallel) is divided by
        ``tp_size``.

        Args:
            max_loras: Number of LoRA slots to allocate.
            lora_config: Supplies the max rank, dtype and sharding mode.
            model_config: Unused by this implementation.

        Raises:
            NotImplementedError: If the base layer is not a replicated,
                column-parallel or row-parallel linear layer.
        """
        self.lora_config = lora_config
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, ColumnParallelLinear):
            lora_a_out_size = (
                lora_config.max_lora_rank
                if not lora_config.fully_sharded_loras
                else divide(lora_config.max_lora_rank, self.tp_size)
            )
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, RowParallelLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (
                self.output_size
                if not lora_config.fully_sharded_loras
                else divide(self.output_size, self.tp_size)
            )
        else:
            raise NotImplementedError

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        self.output_slices = (self.lora_b_stacked[0].shape[2],)

    def reset_lora(self, index: int) -> None:
        """Zero the A and B weights stored in slot ``index`` of every slice."""
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ) -> None:
        """Load one LoRA adapter's weights into slot ``index``.

        The slot is zeroed first; when ``tp_size > 1`` the weights are passed
        through ``slice_lora_a`` / ``slice_lora_b`` before being copied into
        the top-left corner of the stacked buffers.

        Args:
            index: Slot in the stacked buffers to overwrite.
            lora_a: LoRA A weight; must be a single tensor here.
            lora_b: LoRA B weight; must be a single tensor here.

        Raises:
            AssertionError: If a list is passed or the layer has more than
                one slice.
        """
        # Except for QKVParallelLinearWithLoRA and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        assert (
            len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
        )

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
            lora_a, non_blocking=True
        )
        self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
            lora_b, non_blocking=True
        )

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        """Compute base linear output plus LoRA for ``x``.

        Dispatches to the ``lora_linear_async`` custom op when dual-stream
        mode is enabled and a forward context is available; otherwise runs
        the single-stream path.

        Args:
            x: Input activations.
            bias: Optional bias forwarded to the base layer's quant method.

        Returns:
            The base output with the LoRA contribution added.
        """
        # is_forward_context_available for tower modules
        if self._enable_aux_cuda_stream and is_forward_context_available():
            output_size = sum(self.output_slices)
            return torch.ops.vllm.lora_linear_async(
                self.layer_name, output_size, x, bias
            )
        else:
            return self._apply_sync(x, bias)

    def _apply_sync(
        self, x: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run the base quant method, then add LoRA on the same stream."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        return self._apply_lora_to_output(x, output)

    def _apply_base_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Call the base layer's own forward, then add LoRA to its output.

        If the base layer returns a tuple, only its first element is used.
        """
        base_output = self.base_layer(x)
        output = base_output[0] if isinstance(base_output, tuple) else base_output
        return self._apply_lora_to_output(x, output)

    def _apply_lora_to_output(
        self, x: torch.Tensor, output: torch.Tensor
    ) -> torch.Tensor:
        """Add the LoRA contribution for ``x`` onto ``output``.

        Args:
            x: Input activations that produced ``output``.
            output: Base layer output to which LoRA is added.

        Returns:
            The updated output, with a 3D ``output`` restored to its original
            shape. On platforms where ``can_update_inplace()`` is false, the
            tensor returned by ``add_lora_linear`` is used instead of the
            buffer that was passed in.
        """
        original_shape = output.shape if output.ndim == 3 else None

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
        )
        if not current_platform.can_update_inplace():
            output = lora_output

        # Reshape the flattened output back to its original shape,
        # as some MM encoders cannot handle flattened inputs.
        if original_shape is not None:
            output = output.reshape(original_shape)

        return output

    def _apply_async_impl(
        self, x: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Forward pass with base linear and LoRA on separate CUDA streams
        for overlap, using maybe_execute_in_parallel.
        Base layer runs on default stream; LoRA runs on aux stream.

        Args:
            x: Input activations, 2D or 3D.
            bias: Optional bias forwarded to the base layer's quant method.

        Returns:
            The base output with the LoRA result added in place, restored to
            the base output's original shape when that output is 3D.

        Raises:
            AssertionError: If dual-stream mode is disabled or ``x`` is not
                2D or 3D.
        """
        assert envs.VLLM_LORA_ENABLE_DUAL_STREAM
        assert x.ndim in (2, 3)

        # Flatten the batch dimension for the transformers backend
        # (which uses shape (1, seq_len, hidden)), matching _apply_sync.
        # flatten() on these inputs is a metadata-only view, so doing it
        # before the parallel section launches no work on either stream.
        x_2d = x.flatten(0, 1) if x.ndim == 3 else x
        # Size the LoRA buffer from the flattened input so the row count
        # always matches what punica receives, even if the leading batch
        # dimension of a 3D input is larger than 1.
        num_tokens = x_2d.size(0)
        output_size = sum(self.output_slices)

        def base_fn() -> torch.Tensor:
            return self.base_layer.quant_method.apply(self.base_layer, x, bias)

        def lora_fn() -> torch.Tensor:
            # Must be zeros, not empty: _lora_expand_kernel exits early (without
            # writing) when lora_id == -1 (no active LoRA). If uninitialized,
            # output.add_(lora_result) below would corrupt the base output.
            lora_output = torch.zeros(
                (num_tokens, output_size),
                device=self.device,
                dtype=x.dtype,
            )
            self.punica_wrapper.add_lora_linear(
                lora_output,
                x_2d,
                self.lora_a_stacked,
                self.lora_b_stacked,
                1.0,
                self.output_slices,
                add_inputs=False,
            )
            return lora_output

        output, lora_result = maybe_execute_in_parallel(
            base_fn,
            lora_fn,
            self._events[0],
            self._events[1],
            self._lora_stream,
        )

        original_shape = output.shape if output.ndim == 3 else None

        # The LoRA result is 2D (tokens, output_size); flatten a 3D base
        # output the same way so the in-place add lines up row for row.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)

        output.add_(lora_result)

        # Reshape the flattened output back to its original shape,
        # as some MM encoders cannot handle flattened inputs.
        if original_shape is not None:
            output = output.reshape(original_shape)

        return output

    @property
    def weight(self) -> torch.Tensor:
        """Return the base layer's weight tensor, whatever it is named.

        Raises:
            ValueError: If the base layer exposes none of the known
                weight attributes.
        """
        # unquantizedLinear
        if hasattr(self.base_layer, "weight"):
            return self.base_layer.weight
        # Compressed Tensor
        elif hasattr(self.base_layer, "weight_packed"):
            return self.base_layer.weight_packed
        # GPTQ/AWQ
        elif hasattr(self.base_layer, "qweight"):
            return self.base_layer.qweight
        # marlin
        elif hasattr(self.base_layer, "B"):
            return self.base_layer.B
        else:
            raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> torch.Tensor | None:
        """Return the base layer's ``bias`` attribute, or ``None`` if absent."""
        return getattr(self.base_layer, "bias", None)