# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.utils import divide
from vllm.model_executor.custom_op import maybe_get_oot_by_class
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
)
from vllm.platforms import current_platform

from .base_linear import BaseLinearLayerWithLoRA
from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace


def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
    """
    Shared `apply` logic for `ColumnParallelLinearWithLoRA` and the classes
    that inherit from it.

    The base layer's `quant_method.apply` produces the base output; the LoRA
    contribution is then computed in three steps: `add_shrink` projects `x`
    through the LoRA A weights into a float32 buffer, that buffer is
    all-gathered, and `add_expand` projects it through the LoRA B weights
    into the base output.

    Args:
        x: Input activations. Flattened to 2D (keeping the last dimension)
            before the LoRA ops.
        bias: Bias forwarded unchanged to the base layer's
            `quant_method.apply`.
        layer: LoRA layer supplying `base_layer`, `punica_wrapper`,
            `lora_a_stacked`, `lora_b_stacked`, `output_slices` and
            `n_slices`.

    Returns:
        The output tensor, viewed back to the shape the base layer produced.
    """
    assert (
        layer.n_slices
        == len(layer.lora_a_stacked)
        == len(layer.lora_b_stacked)
        == len(layer.output_slices)
    )

    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    x = x.view(-1, x.shape[-1])
    # The right-hand side is evaluated before assignment, so `output.shape`
    # here is still the un-flattened shape of the base output.
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape

    # Since communication is needed, the buffer is directly initialized as a
    # tensor rather than a tuple of tensor.
    local_lora_rank = layer.lora_a_stacked[0].shape[2]
    buffer_shape = (layer.n_slices, x.shape[0], local_lora_rank)
    # Under torch.compile, the local-rank-1 fully-sharded path can otherwise
    # get lowered to a reinterpret view with a non-canonical layout. The
    # Triton shrink op mutates this buffer in place and expects the standard
    # contiguous [slice, token, rank] stride contract.
    buffers = torch.empty_strided(
        buffer_shape,
        (x.shape[0] * local_lora_rank, local_lora_rank, 1),
        dtype=torch.float32,
        device=x.device,
    )
    buffers.zero_()

    shrunk_buffers: torch.Tensor | None = layer.punica_wrapper.add_shrink(
        buffers, x, layer.lora_a_stacked, 1.0
    )

    # On platforms that cannot update in place, the op's return value is the
    # result and replaces the pre-allocated buffer.
    if not current_platform.can_update_inplace():
        buffers = shrunk_buffers

    buffers = tensor_model_parallel_all_gather(buffers)

    lora_output: torch.Tensor | None = layer.punica_wrapper.add_expand(
        output,
        buffers,
        layer.lora_b_stacked,
        layer.output_slices,
        offset_start=0,
        add_input=True,
    )

    if not current_platform.can_update_inplace():
        output = lora_output

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output


class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
    LoRA B is sliced for tensor parallelism.
    There are two types for the `base_layer`:
    1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
    2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__(base_layer)
        # The base_layer type is ColumnParallelLinear or
        # MergedColumnParallelLinear, their weight sharding logic is
        # inconsistent when TP is greater than 1.
        self.is_merged_col_linear = isinstance(base_layer, MergedColumnParallelLinear)
        self.output_size = self.base_layer.output_size_per_partition
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return `lora_a` unchanged; LoRA A is not sliced in this class."""
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return the rows of `lora_b` that belong to this tensor-parallel rank.

        For a merged base layer the weight is treated as two halves stacked
        along dim 0; this rank's shard is taken from each half and the two
        shards are concatenated. Otherwise a single contiguous row range is
        taken.
        """
        # Applicable to cases where the base_layer is
        # MergedColumnParallelLinear.
        if self.is_merged_col_linear:
            # `output_size` is the per-partition size covering both halves.
            shard_size = self.output_size // 2
            # Row at which the second half starts in the unsharded weight.
            offset = lora_b.shape[0] // 2

            left_weight = lora_b[
                self.tp_rank * shard_size : (self.tp_rank + 1) * shard_size, :
            ]
            right_weight = lora_b[
                offset + self.tp_rank * shard_size : offset
                + (self.tp_rank + 1) * shard_size,
                :,
            ]
            lora_b = torch.cat([left_weight, right_weight], dim=0)
        # Applicable to cases where the base_layer is
        # ColumnParallelLinear.
        else:
            shard_size = self.output_size
            start_idx = self.tp_rank * shard_size
            end_idx = (self.tp_rank + 1) * shard_size
            lora_b = lora_b[start_idx:end_idx, :]
        return lora_b

    def forward(
        self, input_: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias

            Only `output` is returned when the base layer's `return_bias`
            is falsy; otherwise the tuple `(output, output_bias)`.
        """
        # The bias is either applied inside `apply` or handed back to the
        # caller, depending on `skip_bias_add`.
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.base_layer.return_bias:
            return output

        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Return whether this class can wrap `source_layer`.

        Accepts a layer whose exact type is the (possibly out-of-tree)
        ColumnParallelLinear, or a MergedColumnParallelLinear instance with a
        single packed module and fewer than three output sizes.
        """
        if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear):
            return True
        if isinstance(source_layer, maybe_get_oot_by_class(MergedColumnParallelLinear)):
            if len(packed_modules_list) != 1:
                return False
            # Exclude layers with 3+ output sizes - those are handled by
            # MergedColumnParallelLinearVariableSliceWithLoRA since this
            # class's slice_lora_b assumes exactly 2 slices.
            return not (
                hasattr(source_layer, "output_sizes")
                and len(source_layer.output_sizes) >= 3
            )
        return False


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(
        self, base_layer: MergedColumnParallelLinear | QKVParallelLinear
    ) -> None:
        super().__init__(base_layer)
        # There are two LoRA layers
        # the output_sizes in MergedColumnParallelLinear is not sharded by tp
        # we need to divide it by the tp_size to get correct slices size
        output_sizes = self.base_layer.output_sizes
        self.output_slices = tuple(
            divide(output_size, self.tp_size) for output_size in output_sizes
        )
        self.n_slices = len(self.output_slices)
        # Every slice is sharded by this rank's index.
        self.output_ids = (self.tp_rank,) * self.n_slices

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """
        The main reason for overriding this function is to enhance code
        maintainability.

        Allocates one zero-initialized LoRA A and one LoRA B buffer per
        slice, each with room for `max_loras` adapters.
        """
        self.lora_config = lora_config

        # With fully sharded LoRAs the rank dimension of LoRA A is split
        # across tensor-parallel ranks.
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank
            if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size)
        )

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                output_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for output_size in self.output_slices
        )

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        """Return `lora_a` unchanged; LoRA A is not sliced in this class."""
        return lora_a

    def slice_lora_b(
        self, lora_b: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        """Slice each present sub-LoRA B down to this rank's rows.

        Slice `i` keeps rows `[size * id, size * (id + 1))`, where `size` and
        `id` come from `output_slices[i]` and `output_ids[i]`. Missing
        (`None`) sub-LoRAs stay `None`.
        """
        sliced_lora_b = [None] * self.n_slices
        for i, (shard_id, shard_size) in enumerate(
            zip(self.output_ids, self.output_slices)
        ):
            if (lora_b_i := lora_b[i]) is not None:
                sliced_lora_b[i] = lora_b_i[
                    shard_size * shard_id : shard_size * (shard_id + 1), :
                ]
        return sliced_lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Load one adapter's per-slice weights into slot `index`.

        The slot is reset first; weights are sliced only when `tp_size > 1`.
        Each present sub-LoRA is copied into the top-left corner of its
        stacked buffer, so weights smaller than the allocated size are
        accepted.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        for i in range(self.n_slices):
            if (lora_a_i := lora_a[i]) is not None:
                self.lora_a_stacked[i][
                    index, 0, : lora_a_i.shape[0], : lora_a_i.shape[1]
                ].copy_(lora_a_i, non_blocking=True)
            if (lora_b_i := lora_b[i]) is not None:
                self.lora_b_stacked[i][
                    index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
                ].copy_(lora_b_i, non_blocking=True)

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        """Compute the layer output with the LoRA contribution.

        Uses `_apply_base_forward` when `tp_size == 1` and the base layer is
        a subclass of the merged class that defines its own `forward`;
        otherwise uses the shared `_mcp_apply` path.
        """
        merged_cls = maybe_get_oot_by_class(MergedColumnParallelLinear)
        # Effectively unsharded subclasses can safely reuse their custom
        # forward() implementation before applying the LoRA delta.
        if (
            self.tp_size == 1
            and type(self.base_layer) is not merged_cls
            and type(self.base_layer).forward is not merged_cls.forward
        ):
            return self._apply_base_forward(x)
        return _mcp_apply(x, bias, self)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
        decorate: bool = True,
    ) -> bool:
        """Return whether this class can wrap `source_layer`.

        Requires a MergedColumnParallelLinear instance (possibly out-of-tree)
        with exactly two packed modules. For the exact merged class, passing
        `decorate=False` skips the fully-sharded check. Subclasses are only
        accepted when `tp_size == 1`.
        """
        merged_cls = maybe_get_oot_by_class(MergedColumnParallelLinear)
        if not isinstance(source_layer, merged_cls) or len(packed_modules_list) != 2:
            return False

        tp_size = getattr(source_layer, "tp_size", 1)
        if type(source_layer) is merged_cls:
            if not decorate:
                return True
            return not lora_config.fully_sharded_loras or tp_size == 1

        # Only support effectively unsharded subclasses here. Sharded
        # subclasses may have custom communication semantics that the generic
        # merged-column LoRA path does not know how to preserve.
        return tp_size == 1


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contains a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        # "total" sizes use the base layer's total head counts and give the
        # Q/K/V offsets in the unsharded weight; "shard" sizes use the base
        # layer's `num_heads` / `num_kv_heads`.
        self.q_proj_total_size = (
            self.base_layer.total_num_heads * self.base_layer.head_size
        )
        self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
        self.kv_proj_shard_size = (
            self.base_layer.num_kv_heads * self.base_layer.head_size
        )
        self.kv_proj_total_size = (
            self.base_layer.total_num_kv_heads * self.base_layer.head_size
        )
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return this rank's Q, K and V rows of `lora_b`, concatenated.

        `lora_b` is indexed as Q rows, then K rows, then V rows along dim 0.
        Also records `q_shard_id` and `kv_shard_id` on the instance.
        """
        self.q_shard_id = self.tp_rank
        # Ranks sharing a replicated KV head map to the same KV shard.
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
        lora_b_q = lora_b[
            self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size
            * (self.q_shard_id + 1),
            :,
        ]
        k_offset = self.q_proj_total_size
        lora_b_k = lora_b[
            k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset
            + self.kv_proj_shard_size * (self.kv_shard_id + 1),
            :,
        ]
        v_offset = k_offset + self.kv_proj_total_size
        lora_b_v = lora_b[
            v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset
            + self.kv_proj_shard_size * (self.kv_shard_id + 1),
            :,
        ]
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
        return lora_b

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Accept only an exact `QKVParallelLinear` with one packed module."""
        return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1


class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
    """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        # There are three LoRA layer.
        self.n_slices = len(self.base_layer.output_sizes)

        self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
        self.kv_proj_shard_size = (
            self.base_layer.num_kv_heads * self.base_layer.head_size
        )
        self.q_shard_id = self.tp_rank
        # Ranks sharing a replicated KV head map to the same KV shard.
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        # Replace the parent's slice sizes and shard ids with Q/K/V-specific
        # values.
        self.output_slices = (
            self.q_proj_shard_size,
            self.kv_proj_shard_size,
            self.kv_proj_shard_size,
        )
        self.output_ids = (
            self.q_shard_id,
            self.kv_shard_id,
            self.kv_shard_id,
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """
        The main reason for overriding this function is to handle inconsistent
        weight dimensions in qkv lora.

        Delegates to the parent implementation, which sizes each LoRA B
        buffer from `output_slices`.
        """
        super().create_lora_weights(max_loras, lora_config, model_config)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Accept only an exact `QKVParallelLinear` with three packed modules."""
        return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3


# These following layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.


class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
    # their `lora_a` and `lora_b` have different sharding patterns. After
    # completing the `lora_a` GEMM , a gather operation is performed.
    # Therefore, the sharding of `lora_a` only needs to correspond with the
    # gather operation.
    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return this rank's rows of `lora_a`.

        The shard size is read from dim 2 of the allocated
        `lora_a_stacked[0]` buffer.
        """
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        lora_a = lora_a[start_idx : start_idx + shard_size, :]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        """Compute the layer output via the shared `_mcp_apply` path."""
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Defer to the parent's check, passing `decorate=False`."""
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        """Slice every present sub-LoRA A down to this rank's rows.

        All sub-LoRAs use the same shard size, read from dim 2 of
        `lora_a_stacked[0]`. Missing (`None`) sub-LoRAs stay `None`.
        """
        # NOTE: lora_a contains one entry per sublora (2 for this layer),
        # and each sublora could be None.
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        output_end_idx = output_start_idx + output_shard_size
        return [
            sub_lora_a[output_start_idx:output_end_idx, :]
            if sub_lora_a is not None
            else None
            for sub_lora_a in lora_a
        ]

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        """Compute the layer output via the shared `_mcp_apply` path."""
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Defer to the parent's check, passing `decorate=False`."""
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
    """
    Differs from QKVParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return this rank's rows of `lora_a`.

        The shard size is read from dim 2 of the allocated
        `lora_a_stacked[0]` buffer.
        """
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        lora_a = lora_a[start_idx : start_idx + shard_size, :]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        """Compute the layer output via the shared `_mcp_apply` path."""
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Defer to the parent's check, passing `decorate=False`."""
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
    """
    Differs from MergedQKVParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        """Slice every present sub-LoRA A down to this rank's rows.

        Sub-LoRA `i` uses the shard size read from dim 2 of
        `lora_a_stacked[i]`. Missing (`None`) sub-LoRAs stay `None`.
        """
        # NOTE: lora_a contains one entry per sublora (3 for q/k/v), and
        # each sublora could be None.
        sliced_lora_a: list[torch.Tensor | None] = []
        for sub_lora_a, stacked in zip(lora_a, self.lora_a_stacked):
            if sub_lora_a is None:
                sliced_lora_a.append(None)
                continue
            shard_size = stacked.shape[2]
            start_idx = self.tp_rank * shard_size
            sliced_lora_a.append(sub_lora_a[start_idx : start_idx + shard_size, :])
        return sliced_lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        """Compute the layer output via the shared `_mcp_apply` path."""
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Defer to the parent's check, passing `decorate=False`."""
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class MergedColumnParallelLinearVariableSliceWithLoRA(
    MergedColumnParallelLinearWithLoRA
):
    """MergedColumnParallelLinear with variable number of slices (3+).

    This handles cases where the checkpoint has a single weight for the whole
    module (not split into slices), but the layer itself has multiple slices.
    """

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Return whether this class can wrap `source_layer`.

        Requires a MergedColumnParallelLinear instance (possibly out-of-tree).
        With three or more packed modules the layer is accepted, with exactly
        two it is rejected, and with fewer it is accepted only if the layer
        has an `output_sizes` attribute of length three or more.
        """
        # Support MergedColumnParallelLinear with 3 or more slices
        # (2 slices are handled by MergedColumnParallelLinearWithLoRA)
        if not isinstance(
            source_layer, maybe_get_oot_by_class(MergedColumnParallelLinear)
        ):
            return False

        # If packed_modules_list has 3+ items, use this class
        if len(packed_modules_list) >= 3:
            return True

        # If packed_modules_list has exactly 2 items, let
        # MergedColumnParallelLinearWithLoRA handle it
        if len(packed_modules_list) == 2:
            return False

        # If packed_modules_list is empty or has 1 item,
        # check the layer's output_sizes.
        # This handles cases where the checkpoint has a single weight
        # but the layer has multiple slices (3+)
        return (
            hasattr(source_layer, "output_sizes")
            and len(source_layer.output_sizes) >= 3
        )

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Override to handle single tensor weights
        that need to be split into slices.

        A single `lora_a` tensor is shared by every slice; a single `lora_b`
        tensor is split along dim 0 using the base layer's `output_sizes`.
        Inputs that are already lists are passed through unchanged.
        """
        # No explicit reset_lora() here: the parent's set_lora() resets the
        # slot before copying, and nothing below touches the stacked buffers.

        # Handle case where checkpoint has single tensor weights
        # lora_a shape: (rank, input_size) - same for all slices, duplicate it
        if isinstance(lora_a, torch.Tensor):
            lora_a = [lora_a] * self.n_slices

        # lora_b shape: (total_output_size, rank) -
        # split along dim 0 based on output_sizes
        if isinstance(lora_b, torch.Tensor):
            output_sizes = self.base_layer.output_sizes
            lora_b_list = []
            start_idx = 0
            for output_size in output_sizes:
                end_idx = start_idx + output_size
                lora_b_list.append(lora_b[start_idx:end_idx, :])
                start_idx = end_idx
            lora_b = lora_b_list

        # Now call parent's set_lora which expects lists
        super().set_lora(index, lora_a, lora_b)