layers.py 45.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
# pylint: disable=unused-argument
import functools
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
from vllm.distributed.utils import divide
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.platforms import current_platform

if TYPE_CHECKING:
    from vllm.lora.punica_wrapper import PunicaWrapperBase


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Return the device on which LoRA tensors for ``base_layer`` belong.

    The device is read from the first weight-like attribute present on the
    layer, probed in this order:

    * ``weight``        - unquantized linear layers
    * ``weight_packed`` - compressed-tensors layers
    * ``qweight``       - GPTQ / AWQ layers
    * ``B``             - marlin layers
    * ``W_q``           - HQQ marlin layers

    Raises:
        ValueError: if the layer has none of the attributes above.
    """
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    for attr_name in ("weight", "weight_packed", "qweight", "B", "W_q"):
        if hasattr(base_layer, attr_name):
            return getattr(base_layer, attr_name).device
    raise ValueError(f"Unsupported base layer: {base_layer}")


def _not_fully_sharded_can_replace(can_replace):
    """Decorator adding the condition "fully sharded LoRAs are not in use".

    Intended to wrap ``can_replace_layer()``. The wrapped callable must be
    invoked with ``lora_config`` as a keyword argument. An optional
    ``decorate=False`` keyword disables the extra condition; it is consumed
    here and never forwarded to ``can_replace``.
    """

    # functools.wraps keeps the wrapped function's name, docstring and
    # ``__wrapped__`` so introspection of can_replace_layer still works.
    @functools.wraps(can_replace)
    def dec(*args, **kwargs):
        # Pop before forwarding so ``can_replace`` never receives ``decorate``.
        decorate = kwargs.pop("decorate", True)
        condition = (not kwargs["lora_config"].fully_sharded_loras
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec


@dataclass
class LoRAMapping(AdapterMapping):
    """LoRA-specific adapter mapping.

    Extends the fields inherited from ``AdapterMapping`` with a single
    ``is_prefill`` flag, which defaults to ``False``.
    """

    # Whether this mapping is for a prefill step (default False).
    is_prefill: bool = False


class BaseLayerWithLoRA(nn.Module):
    """Common interface implemented by every LoRA-wrapped layer.

    The hooks defined here are placeholders. Apart from ``set_mapping``
    (which stores the punica wrapper) and ``can_replace_layer`` (which
    raises ``NotImplementedError``), they do nothing and return ``None``;
    concrete LoRA layers override them.
    """

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Return this rank's tensor-parallel shard of LoRA A.

        Placeholder: returns ``None`` unless overridden.
        """
        return None

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Return this rank's tensor-parallel shard of LoRA B.

        Placeholder: returns ``None`` unless overridden.
        """
        return None

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate the buffers holding up to ``max_loras`` adapters."""
        pass

    def reset_lora(self, index: int):
        """Zero the LoRA weights stored in slot ``index``."""
        pass

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Overwrite the LoRA tensors stored in slot ``index``."""
        pass

    def set_mapping(
        self,
        punica_wrapper,
    ):
        """Remember the punica wrapper that subclasses use in forward."""
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Tell whether ``source_layer`` can be swapped for this LoRA layer."""
        raise NotImplementedError


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``VocabParallelEmbedding``.

    Holds stacked LoRA A/B weights for up to ``max_loras`` adapters plus a
    per-adapter buffer of extra-vocabulary embeddings. When this partition
    owns added-vocab rows, those embeddings are mirrored into the base
    layer's weight.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Assigned in create_lora_weights(); both are None when this
        # partition owns no added (extra-vocab) embedding rows.
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocate zeroed LoRA buffers and carve out the added-vocab rows."""
        base = self.base_layer
        if base.num_added_embeddings_per_partition > 0:
            org_rows = base.num_org_embeddings_per_partition
            added_rows = base.num_added_embeddings_per_partition
            # Basic slicing yields a view, so writes through
            # embeddings_weights land directly in the base layer's weight.
            self.embeddings_weights = base.weight.data[org_rows:org_rows +
                                                       added_rows]
            shard = base.shard_indices
            self.embeddings_slice = (
                shard.added_vocab_start_index - base.org_vocab_size,
                shard.added_vocab_end_index - base.org_vocab_size,
            )
            # Clear every row that follows the original embeddings.
            base.weight.data[org_rows:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        device = base.weight.device
        extra_vocab = lora_config.lora_extra_vocab_size
        max_rank = lora_config.max_lora_rank
        self.embeddings_tensors = torch.zeros(
            (max_loras, extra_vocab, base.embedding_dim),
            dtype=base.weight.dtype,
            device=device,
        )
        self.lora_a_stacked = torch.zeros(
            (max_loras, base.org_vocab_size + extra_vocab, max_rank),
            dtype=lora_config.lora_dtype,
            device=device,
        )
        self.lora_b_stacked = torch.zeros(
            (max_loras, 1, base.embedding_dim, max_rank),
            dtype=lora_config.lora_dtype,
            device=device,
        )
        # Flattened 2-D view of lora_a_stacked; forward() uses it as the
        # F.embedding lookup table.
        num_slots = self.lora_a_stacked.shape[0]
        rows_per_slot = self.lora_a_stacked.shape[1]
        rank = self.lora_a_stacked.shape[2]
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            num_slots * rows_per_slot, rank)

    def reset_lora(self, index: int):
        """Zero the LoRA A/B weights and extra embeddings of slot ``index``."""
        for buffer in (self.lora_a_stacked, self.lora_b_stacked,
                       self.embeddings_tensors):
            buffer[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter into slot ``index`` (``bias`` is ignored)."""
        self.reset_lora(index)

        a_rows, a_cols = lora_a.shape[0], lora_a.shape[1]
        self.lora_a_stacked[index, :a_rows, :a_cols].copy_(lora_a,
                                                           non_blocking=True)
        # lora_b is stored transposed inside lora_b_stacked.
        b_rows, b_cols = lora_b.shape[0], lora_b.shape[1]
        self.lora_b_stacked[index, 0, :b_cols, :b_rows].copy_(
            lora_b.T, non_blocking=True)

        if embeddings_tensor is None:
            return
        emb_rows = embeddings_tensor.shape[0]
        emb_cols = embeddings_tensor.shape[1]
        self.embeddings_tensors[index, :emb_rows, :emb_cols].copy_(
            embeddings_tensor, non_blocking=True)

        if self.embeddings_slice is None:
            return
        # TODO(yard1): Optimize this copy, we don't need to copy
        # everything, just the modified part
        start, end = self.embeddings_slice
        num_slots = self.embeddings_tensors.shape[0]
        slot_rows = self.embeddings_tensors.shape[1]
        dim = self.embeddings_tensors.shape[2]
        embeddings = self.embeddings_tensors.view(num_slots * slot_rows,
                                                  dim)[start:end]
        assert self.embeddings_weights is not None
        self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` with the base layer and add the LoRA contribution."""
        # 1 for token ids beyond the original vocabulary, 0 otherwise.
        added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
                                        1, 0)
        embeddings_indices = torch.narrow(
            self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))
        lora_a_offsets = embeddings_indices[1]
        full_lora_a_embeddings = F.embedding(
            x + lora_a_offsets,
            self.lora_a_stacked_2d,
        )
        # Only added-vocab tokens are shifted by the per-request offset.
        base_offsets = embeddings_indices[0]
        full_output = self.base_layer.forward(x +
                                              (base_offsets * added_tokens_mask))

        original_view = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )
        self.punica_wrapper.add_lora_embedding(full_output,
                                               full_lora_a_embeddings,
                                               self.lora_b_stacked,
                                               add_input=True)
        return full_output.view_as(original_view)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Only an exact ``VocabParallelEmbedding`` (no subclass) qualifies."""
        return type(source_layer) is VocabParallelEmbedding

    @property
    def weight(self):
        """The wrapped embedding's weight."""
        return self.base_layer.weight


class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """Shared LoRA logic for the linear-layer wrappers in this module.

    Subclasses must assign ``tp_size``, ``output_size`` and ``n_slices``
    (declared, but not set, in ``__init__``) before
    ``create_lora_weights()`` is called.
    """

    def __init__(self, base_layer: LinearBase):
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        self.device = _get_lora_device(self.base_layer)
        # Stays None unless create_lora_weights() runs with bias_enabled.
        self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None

        self.output_slices: Tuple[int, ...]
        self.tp_size: int
        self.output_size: int
        self.n_slices: int

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed LoRA stacks, one tensor per slice.

        A tensors are ``(max_loras, 1, lora_a_out_size, input_size)`` and B
        tensors are ``(max_loras, 1, lora_b_out_size, max_lora_rank)``. The
        two out sizes depend on the base layer's parallel type and on
        ``lora_config.fully_sharded_loras``. Bias stacks are only allocated
        when ``lora_config.bias_enabled`` is set.

        Raises:
            NotImplementedError: if the base layer is not (a subclass of)
                ReplicatedLinear, ColumnParallelLinear or RowParallelLinear.
        """
        self.lora_config = lora_config

        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size
        elif isinstance(self.base_layer, ColumnParallelLinear):
            # Fully sharded LoRA splits the rank dimension of A across TP.
            lora_a_out_size = (lora_config.max_lora_rank if
                               not lora_config.fully_sharded_loras else divide(
                                   lora_config.max_lora_rank, self.tp_size))
            lora_b_out_size = self.output_size
        elif isinstance(self.base_layer, RowParallelLinear):
            lora_a_out_size = lora_config.max_lora_rank
            # Fully sharded LoRA splits the output dimension of B across TP.
            lora_b_out_size = (self.output_size if
                               not lora_config.fully_sharded_loras else divide(
                                   self.output_size, self.tp_size))
        else:
            raise NotImplementedError

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        if lora_config.bias_enabled:
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    lora_b_out_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for _ in range(self.n_slices))
        self.output_slices = (self.lora_b_stacked[0].shape[2], )

    def reset_lora(self, index: int):
        """Zero slot ``index`` in every A/B (and, if enabled, bias) stack."""
        bias_enabled = self.lora_config.bias_enabled
        # Make mypy happy; the cast is loop-invariant so do it once.
        bias_stacked = cast(Tuple[torch.Tensor, ...], self.lora_bias_stacked)
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0
            if bias_enabled:
                bias_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter into slot ``index`` of the single-slice stacks.

        Under tensor parallelism (``tp_size > 1``) the inputs are first
        narrowed with ``slice_lora_a``/``slice_lora_b``/``slice_bias``.
        A and B are stored transposed; the bias is stored as is.
        """
        # Layers that store one tensor per slice
        # (MergedColumnParallelLinearWithLoRA and its subclass
        # MergedQKVParallelLinearWithLoRA) override this method; every other
        # linear LoRA layer keeps its weights in a tuple of size 1.
        assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
                self.n_slices == 1)

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        self.lora_a_stacked[0][index,
                               0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                   lora_a.T, non_blocking=True)
        self.lora_b_stacked[0][index,
                               0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                   lora_b.T, non_blocking=True)
        if lora_bias is not None:
            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            assert len(self.lora_bias_stacked)
            # The bias is 1-D: copy it directly. ``Tensor.T`` on tensors that
            # are not 2-D is deprecated in PyTorch and only emits a warning.
            self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
                lora_bias, non_blocking=True)

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the base layer's quant method, then add the LoRA output."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        # NOTE(review): the flattened output is what gets returned, so 3-D
        # inputs yield a 2-D result — confirm callers expect that.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        # The return value is discarded: add_lora_linear is relied on to
        # update ``output`` in place.
        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                            self.lora_b_stacked,
                                            self.lora_bias_stacked, 1.0,
                                            self.output_slices)
        return output

    @property
    def weight(self) -> torch.Tensor:
        """Return the base layer's weight-like tensor.

        Probes, in order: ``weight`` (unquantized), ``weight_packed``
        (compressed tensors), ``qweight`` (GPTQ/AWQ), ``B`` (marlin) and
        ``W_q`` (HQQ marlin).

        Raises:
            ValueError: if none of those attributes exists.
        """
        for attr_name in ("weight", "weight_packed", "qweight", "B", "W_q"):
            if hasattr(self.base_layer, attr_name):
                return getattr(self.base_layer, attr_name)
        raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> Optional[torch.Tensor]:
        """The base layer's ``bias`` attribute, or None if it has none."""
        return getattr(self.base_layer, "bias", None)


class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA wrapper for ``ReplicatedLinear``, whose weights are copied per
    GPU rather than split across tensor-parallel ranks."""

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__(base_layer)
        # To ensure interface compatibility, set to 1 always.
        self.tp_size = 1
        self.output_size = self.base_layer.output_size
        self.n_slices = 1

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            ``output`` alone, or ``(output, bias)`` when the base layer has
            ``return_bias`` set; the returned bias is None unless
            ``skip_bias_add`` is set.
        """
        skip_bias_add = self.base_layer.skip_bias_add
        # Either fuse the bias into the matmul or hand it back separately.
        fused_bias = None if skip_bias_add else self.base_layer.bias
        output = self.apply(input_, fused_bias)

        if not self.base_layer.return_bias:
            return output
        return output, (self.base_layer.bias if skip_bias_add else None)

    # ReplicatedLinear should always be replaced, regardless of the fully
    # sharded LoRAs setting, because it is, by definition, copied per GPU.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Accept only an exact ``ReplicatedLinear`` (no subclass)."""
        return type(source_layer) is ReplicatedLinear


class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
    LoRA B (and the LoRA bias) are sliced for tensor parallelism.

    There are two types for the `base_layer`:
    1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
    2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__(base_layer)
        # The base_layer type is ColumnParallelLinear or
        # MergedColumnParallelLinear, their weight sharding logic is
        # inconsistent when TP is greater than 1.
        self.is_merged_col_linear = type(
            base_layer) is MergedColumnParallelLinear
        self.tp_size = get_tensor_model_parallel_world_size()
        self.output_size = self.base_layer.output_size_per_partition
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """LoRA A is not split for column-parallel layers."""
        return lora_a

    def _slice_output_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return this rank's shard of ``tensor`` along its last dimension.

        For a MergedColumnParallelLinear base layer the last dimension holds
        two equal halves (e.g. gate | up); this rank's shard of each half is
        taken and the two pieces are concatenated, mirroring how the base
        weight is sharded. Otherwise a single contiguous shard is taken.
        """
        tp_rank = get_tensor_model_parallel_rank()
        if self.is_merged_col_linear:
            shard_size = self.output_size // 2
            offset = tensor.shape[-1] // 2
            left = tensor[..., tp_rank * shard_size:(tp_rank + 1) * shard_size]
            right = tensor[..., offset + tp_rank * shard_size:offset +
                           (tp_rank + 1) * shard_size]
            return torch.cat([left, right], dim=-1)
        shard_size = self.output_size
        return tensor[..., tp_rank * shard_size:(tp_rank + 1) * shard_size]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Shard LoRA B (shape ``(rank, out)``) along its output columns."""
        return self._slice_output_dim(lora_b)

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Shard the 1-D LoRA bias exactly like the columns of LoRA B.

        Previously the bias was always sliced contiguously. For a merged
        (gate|up) base layer that selects the wrong elements and no longer
        matches the left/right-half layout used for LoRA B.
        """
        if bias is None:
            return bias
        return self._slice_output_dim(bias)

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.base_layer.return_bias:
            return output

        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Accept a plain ColumnParallelLinear, or a MergedColumnParallelLinear
        that is targeted by a single (unpacked) LoRA module."""
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer composed of several sublayers (slices)
    packed together (eg. gate_proj + up_proj -> gate_up_proj).

    One LoRA is applied per slice. The number of slices and each slice's
    per-rank size come from ``base_layer.output_sizes`` divided by the
    tensor-parallel size.
    """

    def __init__(
        self, base_layer: Union[MergedColumnParallelLinear,
                                QKVParallelLinear]) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # the output_sizes in MergedColumnParallelLinear is not sharded by tp
        # we need to divide it by the tp_size to get correct slices size
        output_sizes = self.base_layer.output_sizes
        self.output_slices = tuple(
            divide(output_size, self.tp_size) for output_size in output_sizes)
        # One LoRA per packed sublayer.
        self.n_slices = len(self.output_slices)
        # Every slice is sharded by this rank's index.
        self.output_ids = (self.tp_rank, ) * self.n_slices

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate per-slice LoRA stacks.

        The main reason for overriding this function is to enhance code
        maintainability: each slice's B (and bias) buffer is sized from its
        own entry of ``self.output_slices``.
        """
        self.lora_config = lora_config

        # Fully sharded LoRA splits the rank dimension of A across TP.
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                output_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for output_size in self.output_slices)
        if lora_config.bias_enabled:
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    output_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for output_size in self.output_slices)

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """LoRA A is not split for column-parallel layers."""
        return lora_a

    def _slice_per_output(
        self, tensors: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Return a NEW list where each non-None entry is narrowed to this
        rank's shard along its last (output) dimension.

        The input list is left untouched. The caller's list belongs to the
        cached adapter weights, and slicing it in place would make a second
        ``set_lora`` of the same adapter re-slice already-sliced tensors.
        """
        sliced = list(tensors)
        for i, (shard_id, shard_size) in enumerate(
                zip(self.output_ids, self.output_slices)):
            if (tensor_i := sliced[i]) is not None:
                start = shard_size * shard_id
                sliced[i] = tensor_i[..., start:start + shard_size]
        return sliced

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Shard each slice's LoRA B along its output columns."""
        return self._slice_per_output(lora_b)

    def slice_bias(
        self, bias: List[Union[torch.Tensor,
                               None]]) -> List[Union[torch.Tensor, None]]:
        """Shard each slice's 1-D LoRA bias like the matching LoRA B."""
        return self._slice_per_output(bias)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter into slot ``index``, slice by slice.

        ``lora_a``, ``lora_b`` and ``lora_bias`` are indexed per slice
        (``[i]`` for ``i < n_slices``). A ``None`` entry leaves that slice
        zeroed.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        for i in range(self.n_slices):
            if (lora_a_i := lora_a[i]) is not None:
                self.lora_a_stacked[i][
                    index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
                        lora_a_i.T, non_blocking=True)
            if (lora_b_i := lora_b[i]) is not None:
                self.lora_b_stacked[i][
                    index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
                        lora_b_i.T, non_blocking=True)

        if lora_bias is not None:
            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            for i in range(self.n_slices):
                if (lora_bias_i := lora_bias[i]) is not None:
                    # Bias slices are 1-D: copy directly (``Tensor.T`` on
                    # non-2-D tensors is deprecated in PyTorch).
                    self.lora_bias_stacked[i][
                        index, 0, :lora_bias_i.shape[0]].copy_(
                            lora_bias_i, non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Accept a MergedColumnParallelLinear packing exactly two modules."""
        return (type(source_layer) is MergedColumnParallelLinear
                and len(packed_modules_list) == 2)


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    (and the LoRA bias) must be accurately partitioned according to the
    respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        head_size = self.base_layer.head_size
        self.q_proj_total_size = self.base_layer.total_num_heads * head_size
        self.q_proj_shard_size = self.base_layer.num_heads * head_size
        self.kv_proj_shard_size = self.base_layer.num_kv_heads * head_size
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   head_size)
        # Shard ids are fixed for this rank. Computing them here (instead of
        # only inside slice_lora_b) means slice_bias no longer depends on
        # slice_lora_b having been called first. The KV id is divided by
        # num_kv_head_replicas, presumably because KV heads can be
        # replicated across ranks.
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        # There is only one LoRA layer
        self.n_slices = 1

    def _slice_qkv(self, tensor: torch.Tensor) -> torch.Tensor:
        """Gather this rank's q, k and v shards along the last dimension.

        The last dimension of ``tensor`` is laid out as the full (unsharded)
        ``[q | k | v]`` projection output.
        """
        q_start = self.q_proj_shard_size * self.q_shard_id
        kv_start = self.kv_proj_shard_size * self.kv_shard_id
        k_offset = self.q_proj_total_size
        v_offset = k_offset + self.kv_proj_total_size
        q_part = tensor[..., q_start:q_start + self.q_proj_shard_size]
        k_part = tensor[..., k_offset + kv_start:k_offset + kv_start +
                        self.kv_proj_shard_size]
        v_part = tensor[..., v_offset + kv_start:v_offset + kv_start +
                        self.kv_proj_shard_size]
        return torch.cat([q_part, k_part, v_part], dim=-1)

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Shard LoRA B (shape ``(rank, q+k+v)``) for this rank."""
        return self._slice_qkv(lora_b)

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Shard the 1-D LoRA bias for this rank.

        The previous implementation concatenated along ``dim=1``, which
        does not exist on a 1-D tensor and raised an IndexError.
        """
        return self._slice_qkv(bias)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Accept a QKVParallelLinear targeted by a single LoRA module."""
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1


class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
    """LoRA wrapper for a QKVParallelLinear built from three packed modules.

    The q_proj, k_proj and v_proj sub-layers are packed into a single qkv_proj
    (q_proj + k_proj + v_proj -> qkv_proj), and each of the three slices gets
    its own LoRA. The Q slice can be sized differently from the K and V
    slices, which always share one shape.
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        layer = self.base_layer

        # One LoRA slice per packed projection.
        self.n_slices = len(layer.output_sizes)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.q_proj_shard_size = layer.num_heads * layer.head_size
        self.kv_proj_shard_size = layer.num_kv_heads * layer.head_size

        self.q_shard_id = self.tp_rank
        # Floor division maps every `num_kv_head_replicas` consecutive ranks
        # onto the same KV shard id.
        self.kv_shard_id = self.tp_rank // layer.num_kv_head_replicas

        q_size, kv_size = self.q_proj_shard_size, self.kv_proj_shard_size
        self.output_slices = (q_size, kv_size, kv_size)
        self.output_ids = (self.q_shard_id, self.kv_shard_id,
                           self.kv_shard_id)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate LoRA weights by delegating to the parent implementation.

        This override adds no logic of its own. The differing q vs. k/v sizes
        come from ``output_slices`` set in ``__init__``; presumably the
        parent sizes its buffers from those (TODO confirm).
        """
        super().create_lora_weights(max_loras, lora_config, model_config)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Accept an exact ``QKVParallelLinear`` packed from three modules."""
        if type(source_layer) is not QKVParallelLinear:
            return False
        return len(packed_modules_list) == 3
867

868

869
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA wrapper for RowParallelLinear (input dimension is sharded)."""

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__(base_layer)

        self.tp_size = get_tensor_model_parallel_world_size()
        # The wrapped layer only sees its own input partition, so the LoRA
        # input size is the per-partition size.
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size

        self.tp_rank = get_tensor_model_parallel_rank()
        # A single LoRA covers the whole layer.
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Keep the rows of ``lora_a`` that belong to this rank's input shard."""
        shard = self.input_size
        begin = self.tp_rank * shard
        return lora_a[begin:begin + shard, :]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Output dim is not sharded, so ``lora_b`` is used unchanged."""
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Output dim is not sharded, so ``bias`` is used unchanged."""
        return bias

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward of RowParallelLinear with LoRA applied.

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            The output tensor alone when ``return_bias`` is false, otherwise
            ``(output, bias)``; the bias element is None unless
            ``skip_bias_add`` is set.
        """
        base = self.base_layer

        if base.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            # Keep only this rank's chunk of the last dimension.
            chunks = split_tensor_along_last_dim(
                input_, num_partitions=base.tp_size)
            input_parallel = chunks[self.tp_rank].contiguous()

        # Base matmul plus LoRA contribution on the local partition.
        output_parallel = self.apply(input_parallel)

        # Partial results from every rank are summed when requested.
        needs_reduce = base.reduce_results and base.tp_size > 1
        output_ = (tensor_model_parallel_all_reduce(output_parallel)
                   if needs_reduce else output_parallel)

        # The bias is added once, after the reduction, unless the caller
        # asked to receive it separately.
        if base.skip_bias_add:
            output, output_bias = output_, base.bias
        else:
            output = output_ if base.bias is None else output_ + base.bias
            output_bias = None

        if base.return_bias:
            return output, output_bias
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Accept exactly ``RowParallelLinear`` (not subclasses)."""
        return type(source_layer) is RowParallelLinear
950

951

952
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[List[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    # The properties below forward configuration from the wrapped
    # LogitsProcessor, so the base class's forward can run against `self`.

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_all_gather(self):
        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate the stacked LoRA A/B buffers and extra-vocab embeddings.

        Raises:
            ValueError: if the base vocabulary is larger than 257024.
        """
        # TODO: Verify if this condition can be further relaxed
        # Only an upper bound is enforced; vocabularies below 32000 are
        # accepted.
        vocab_size = self.base_layer.vocab_size
        if vocab_size > 257024:
            raise ValueError(
                "When using LoRA, vocab size must be <= 257024, got "
                f"{vocab_size}")
        # (max_loras, 1, max_rank, hidden_size)
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # (max_loras, 1, padded_vocab, max_rank)
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Embeddings for tokens added by LoRA adapters. Rows not filled by
        # set_lora keep -inf.
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        """Clear slot `index`: zero A/B and restore -inf extra embeddings."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter into slot `index`.

        The slot is reset first. The transposes of `lora_a` / `lora_b` are
        copied into the top-left corner of the stacked buffers. When
        present, `embeddings_tensor` fills the leading rows and columns of
        the extra-vocab embeddings. `bias` is accepted but not used.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute logits including LoRA deltas and LoRA-added vocabulary.

        Returns None when the base layer's gather yields no logits on this
        rank.
        """
        # Get the logits for the next tokens.
        logits = lm_head.quant_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias

        # Gather logits for TP
        logits = self.base_layer._gather_logits(logits)

        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        # Logits for the extra (LoRA-added) vocab of every slot, plus one
        # trailing all -inf slot (presumably selected for tokens with no
        # active LoRA; TODO confirm against sampler_indices_padded).
        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded
        # Flatten (slot, token) into rows and pick each token's row; NaNs
        # (e.g. from -inf * 0) become -inf.
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                      posinf=float("inf"),
                                                      neginf=float("-inf")))

        # HPU needs special handling to prune out dummy samples.
        if current_platform.is_hpu():
            lora_logits = lora_logits[:logits.shape[0], :]

        # Extra-vocab logits sit right after the original vocabulary.
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        # LogitsProcessorWithLoRA always using bgmv
        self.punica_wrapper.add_lora_logits(logits, hidden_states,
                                            self.lora_a_stacked,
                                            self.lora_b_stacked, 1.0)

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        # Run the base class's forward with `self`, so it dispatches to
        # this wrapper's _get_logits and forwarded properties.
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False
1171
1172


1173
class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
    """Implements RoPE-scaled embeddings with linear scaling for
    multiple LoRA adapters with a specialized kernel.

    On `create_lora_weights` the wrapped RotaryEmbedding (or
    LinearScalingRotaryEmbedding) is replaced with a
    LinearScalingRotaryEmbedding that holds every scaling factor required:
    the base layer's own factor(s) plus the configured long-LoRA factors.
    `forward` selects among them through the `offsets` argument, fed from
    `punica_wrapper.long_lora_indices`.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Rebuild the base layer with the union of all needed factors."""
        scaling_factors = (list(lora_config.long_lora_scaling_factors)
                           if lora_config.long_lora_scaling_factors else [])
        if isinstance(self.base_layer, LinearScalingRotaryEmbedding):
            # LinearScalingRotaryEmbedding stores its factors in
            # `scaling_factors` (the same attribute the `scaling_factors`
            # property of this class reads). Keep all of them.
            base_factors = self.base_layer.scaling_factors
            if isinstance(base_factors, (int, float)):
                base_factors = [base_factors]
            base_factors = list(base_factors)
        else:
            # Plain RotaryEmbedding is equivalent to a factor of 1.0.
            base_factors = [1.0]
        scaling_factors = sorted(set(base_factors + scaling_factors))
        self.base_layer = LinearScalingRotaryEmbedding(
            self.base_layer.head_size,
            self.base_layer.rotary_dim,
            self.base_layer.max_position_embeddings,
            self.base_layer.base,
            self.base_layer.is_neox_style,
            scaling_factors,
            self.base_layer.dtype,
        )

    def reset_lora(self, index: int):
        # No per-adapter state to clear.
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        # No per-adapter weights; scaling is selected via offsets instead.
        ...

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.base_layer(
            positions,
            query,
            key,
            offsets=self.punica_wrapper.long_lora_indices,
        )

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return (type(source_layer) is LinearScalingRotaryEmbedding
                or type(source_layer) is RotaryEmbedding)

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()