# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# pylint: disable=unused-argument
import math
from dataclasses import dataclass
7
from typing import TYPE_CHECKING, Optional, Union, cast
8
9
10
11
12
13

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

14
from vllm.adapter_commons.layers import AdapterMapping
15
from vllm.config import LoRAConfig
16
17
18
19
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
20
                              tensor_model_parallel_all_reduce)
21
from vllm.distributed.utils import divide
22
# yapf: disable
23
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
24
                                               LinearBase,
25
                                               MergedColumnParallelLinear,
26
                                               QKVParallelLinear,
27
                                               ReplicatedLinear,
28
                                               RowParallelLinear)
29
# yapf: enable
30
from vllm.model_executor.layers.logits_processor import LogitsProcessor
31
32
from vllm.model_executor.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
33
from vllm.model_executor.layers.vocab_parallel_embedding import (
34
    VocabParallelEmbedding)
35
from vllm.platforms import current_platform
36
37

if TYPE_CHECKING:
38
    from vllm.lora.punica_wrapper import PunicaWrapperBase
39
40


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Return the device on which LoRA tensors for ``base_layer`` belong.

    The device is read from the first parameter found among the attribute
    names used by the supported base-layer flavours, checked in this order:
    ``weight`` (unquantized linear), ``weight_packed`` (compressed tensors),
    ``qweight`` (GPTQ/AWQ), ``B`` (marlin) and ``W_q`` (HQQ marlin).

    Raises:
        ValueError: If ``base_layer`` has none of those attributes.
    """
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    probe_order = ("weight", "weight_packed", "qweight", "B", "W_q")
    for attr_name in probe_order:
        if hasattr(base_layer, attr_name):
            return getattr(base_layer, attr_name).device
    raise ValueError(f"Unsupported base layer: {base_layer}")


def _not_fully_sharded_can_replace(can_replace):
    """Decorate ``can_replace_layer`` so it also requires non-sharded LoRAs.

    The wrapped predicate is truthy only when ``can_replace`` is truthy AND
    ``lora_config.fully_sharded_loras`` is false. Callers may pass
    ``decorate=False`` to skip the extra condition. That keyword is consumed
    here and never forwarded to ``can_replace``.

    ``lora_config`` must be passed as a keyword argument (it is read from
    ``kwargs``).
    """
    import functools

    # wraps() keeps the decorated can_replace_layer's name/docstring.
    @functools.wraps(can_replace)
    def dec(*args, **kwargs):
        decorate = kwargs.pop("decorate", True)
        # lora_config is only consulted when the extra condition applies.
        condition = (not kwargs["lora_config"].fully_sharded_loras
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec


@dataclass
class LoRAMapping(AdapterMapping):
    """LoRA flavour of ``AdapterMapping`` carrying an ``is_prefill`` flag."""

    # Defaults to False; presumably set by the caller for prefill batches.
    is_prefill: bool = False


class BaseLayerWithLoRA(nn.Module):
    """Interface implemented by every LoRA-wrapped layer.

    The hooks below are placeholders that do nothing and return ``None``.
    Concrete LoRA layers override the ones they need.
    ``can_replace_layer`` has no default answer and raises
    ``NotImplementedError`` until a subclass provides one.
    """

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
        """Return this rank's tensor-parallel shard of LoRA A (placeholder)."""

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
        """Return this rank's tensor-parallel shard of LoRA B (placeholder)."""

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate the LoRA buffers for ``max_loras`` slots (placeholder)."""

    def reset_lora(self, index: int):
        """Zero the LoRA weights held in slot ``index`` (placeholder)."""

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Write the given LoRA tensors into slot ``index`` (placeholder)."""

    def set_mapping(
        self,
        punica_wrapper,
    ):
        """Store the punica wrapper that subclasses use to apply LoRA."""
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Tell whether ``source_layer`` may be swapped for this LoRA layer."""
        raise NotImplementedError


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``VocabParallelEmbedding``.

    LoRA A is kept as a per-slot embedding table covering the original
    vocabulary plus ``lora_extra_vocab_size`` extra rows. LoRA B is applied
    through the punica wrapper. When this partition owns added-vocabulary
    rows, LoRA-provided extra-token embeddings are copied into those rows of
    the base layer's weight.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Both are assigned in create_lora_weights().
        self.embeddings_slice: Optional[tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        base = self.base_layer

        if base.num_added_embeddings_per_partition <= 0:
            self.embeddings_slice = None
            self.embeddings_weights = None
        else:
            # This partition owns added-vocab rows: keep a view of them so
            # LoRA extra embeddings can be written in place, then clear them.
            org_rows = base.num_org_embeddings_per_partition
            added_rows = base.num_added_embeddings_per_partition
            self.embeddings_weights = base.weight.data[org_rows:org_rows +
                                                       added_rows]
            shard = base.shard_indices
            self.embeddings_slice = (
                shard.added_vocab_start_index - base.org_vocab_size,
                shard.added_vocab_end_index - base.org_vocab_size,
            )
            base.weight.data[org_rows:].fill_(0)

        extra_vocab = lora_config.lora_extra_vocab_size
        rank = lora_config.max_lora_rank
        device = base.weight.device

        self.embeddings_tensors = torch.zeros(
            (max_loras, extra_vocab, base.embedding_dim),
            dtype=base.weight.dtype,
            device=device,
        )
        self.lora_a_stacked = torch.zeros(
            (max_loras, base.org_vocab_size + extra_vocab, rank),
            dtype=lora_config.lora_dtype,
            device=device,
        )
        self.lora_b_stacked = torch.zeros(
            (max_loras, 1, base.embedding_dim, rank),
            dtype=lora_config.lora_dtype,
            device=device,
        )
        # Flattened (slot * row, rank) view used by forward(), which looks up
        # token ids shifted by per-token offsets from the punica wrapper.
        num_slots, rows_per_slot, lora_rank = self.lora_a_stacked.shape
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            num_slots * rows_per_slot, lora_rank)

    def reset_lora(self, index: int):
        for buffer in (self.lora_a_stacked, self.lora_b_stacked,
                       self.embeddings_tensors):
            buffer[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        self.reset_lora(index)

        a_rows, a_cols = lora_a.shape[0], lora_a.shape[1]
        self.lora_a_stacked[index, :a_rows, :a_cols].copy_(lora_a,
                                                           non_blocking=True)
        # LoRA B is stored transposed.
        b_rows, b_cols = lora_b.shape[0], lora_b.shape[1]
        self.lora_b_stacked[index, 0, :b_cols, :b_rows].copy_(
            lora_b.T, non_blocking=True)

        if embeddings_tensor is None:
            return
        e_rows, e_cols = embeddings_tensor.shape[0], embeddings_tensor.shape[1]
        self.embeddings_tensors[index, :e_rows, :e_cols].copy_(
            embeddings_tensor, non_blocking=True)

        if self.embeddings_slice is None:
            return
        # TODO(yard1): Optimize this copy, we don't need to copy
        # everything, just the modified part
        slots, rows, dim = self.embeddings_tensors.shape
        flat_embeddings = self.embeddings_tensors.view(slots * rows, dim)
        embeddings = flat_embeddings[
            self.embeddings_slice[0]:self.embeddings_slice[1]]
        assert self.embeddings_weights is not None
        self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Token ids beyond the original vocabulary are LoRA-added tokens.
        added_tokens_mask = torch.where(
            x > self.base_layer.org_vocab_size - 1, 1, 0)
        embeddings_indices = self.punica_wrapper._embeddings_indices.narrow(
            1, 0, x.size(0))
        lora_a_offsets = embeddings_indices[1]
        base_offsets = embeddings_indices[0]

        full_lora_a_embeddings = F.embedding(
            x + lora_a_offsets,
            self.lora_a_stacked_2d,
        )
        # Only the added tokens get shifted before the base-layer lookup.
        full_output = self.base_layer.forward(x +
                                              (base_offsets * added_tokens_mask))

        original_output = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )

        lora_output: Optional[torch.Tensor] = (
            self.punica_wrapper.add_lora_embedding(full_output,
                                                   full_lora_a_embeddings,
                                                   self.lora_b_stacked,
                                                   add_input=True))
        if not current_platform.can_update_inplace():
            full_output = lora_output

        return full_output.view_as(original_output)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Exact type match: subclasses of VocabParallelEmbedding are excluded.
        return type(source_layer) is VocabParallelEmbedding

    @property
    def weight(self):
        return self.base_layer.weight


class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """Shared LoRA logic for the linear-layer wrappers.

    Subclasses must assign ``tp_size``, ``output_size`` and ``n_slices`` in
    their ``__init__``. ``create_lora_weights`` reads them to size the
    per-slice stacked buffers:

    * ``lora_a_stacked[i]``: ``(max_loras, 1, lora_a_out_size, input_size)``
    * ``lora_b_stacked[i]``: ``(max_loras, 1, lora_b_out_size, max_lora_rank)``
    * ``lora_bias_stacked[i]`` (only if ``bias_enabled``):
      ``(max_loras, 1, lora_b_out_size)``
    """

    def __init__(self, base_layer: LinearBase):
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        self.device = _get_lora_device(self.base_layer)
        # Stays None unless create_lora_weights() runs with bias enabled.
        self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None

        # Declared here; assigned by subclasses / create_lora_weights().
        self.output_slices: tuple[int, ...]
        self.tp_size: int
        self.output_size: int
        self.n_slices: int

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed LoRA A/B (and optionally bias) buffers.

        Raises:
            NotImplementedError: If the base layer is not a replicated,
                column-parallel or row-parallel linear layer.
        """
        self.lora_config = lora_config
        # With fully sharded LoRAs, column-parallel layers also shard A's
        # rank dimension and row-parallel layers shard B's output dimension.
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, ColumnParallelLinear):
            lora_a_out_size = (lora_config.max_lora_rank if
                               not lora_config.fully_sharded_loras else divide(
                                   lora_config.max_lora_rank, self.tp_size))
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, RowParallelLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (self.output_size if
                               not lora_config.fully_sharded_loras else divide(
                                   self.output_size, self.tp_size))
        else:
            raise NotImplementedError

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        if lora_config.bias_enabled:
            lora_bias_out_size = lora_b_out_size
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    lora_bias_out_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for _ in range(self.n_slices))
        self.output_slices = (self.lora_b_stacked[0].shape[2], )

    def reset_lora(self, index: int):
        """Zero slot ``index`` in every slice's A, B and (if enabled) bias."""
        bias_enabled = self.lora_config.bias_enabled
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0
            if bias_enabled:
                # Make mypy happy
                self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                              self.lora_bias_stacked)
                self.lora_bias_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Copy one adapter's weights into slot ``index``.

        A and B are stored transposed. When ``tp_size > 1`` they are first
        reduced to this rank's shard via ``slice_lora_a``/``slice_lora_b``.
        The bias is reduced via ``slice_bias``, which subclasses provide.
        """
        # Except for QKVParallelLinearWithLoRA and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
                self.n_slices == 1)

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        self.lora_a_stacked[0][index,
                               0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                   lora_a.T, non_blocking=True)
        self.lora_b_stacked[0][index,
                               0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                   lora_b.T, non_blocking=True)
        if lora_bias is not None:
            self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            assert len(self.lora_bias_stacked)
            # The destination is 1-D, so the bias is copied as-is.
            # ``Tensor.T`` on a non-2-D tensor is deprecated in PyTorch.
            self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
                lora_bias, non_blocking=True)

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the base layer's quant method, then add the LoRA contribution.

        If both ``x`` and the base output are 3-D, the result is returned
        with the first two dimensions flattened.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_linear(
                output, x, self.lora_a_stacked, self.lora_b_stacked,
                self.lora_bias_stacked, 1.0, self.output_slices)
        if not current_platform.can_update_inplace():
            output = lora_output

        return output

    @property
    def weight(self) -> torch.Tensor:
        # Probe, in order: unquantized linear, compressed tensors, GPTQ/AWQ,
        # marlin, HQQ marlin.
        for attr_name in ("weight", "weight_packed", "qweight", "B", "W_q"):
            if hasattr(self.base_layer, attr_name):
                return getattr(self.base_layer, attr_name)
        raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> Optional[torch.Tensor]:
        return getattr(self.base_layer, "bias", None)


class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA wrapper for ``ReplicatedLinear``; nothing is split across ranks."""

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__(base_layer)
        # A replicated layer is never sharded. Reporting a TP size of 1 keeps
        # the shared interface uniform, and set_lora never tries to slice.
        self.n_slices = 1
        self.output_size = self.base_layer.output_size
        self.tp_size = 1

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Apply the replicated linear layer plus LoRA.

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            The output tensor, or ``(output, output_bias)`` when the base
            layer has ``return_bias`` set. ``output_bias`` is the base bias
            when ``skip_bias_add`` is set (so it was not added to
            ``output``); otherwise it is None.
        """
        if self.base_layer.skip_bias_add:
            fused_bias, deferred_bias = None, self.base_layer.bias
        else:
            fused_bias, deferred_bias = self.base_layer.bias, None

        output = self.apply(input_, fused_bias)

        if self.base_layer.return_bias:
            return output, deferred_bias
        return output

    # ReplicatedLinear should always be replaced, regardless of the fully
    # sharded LoRAs setting, because it is, by definition, copied per GPU.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is ReplicatedLinear


class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
    LoRA B (and the LoRA bias) is sliced for tensor parallelism.
    There are two types for the `base_layer`:
    1. ColumnParallelLinear, e.g. `dense_h_to_4h` in `FalconForCausalLM`.
    2. MergedColumnParallelLinear, e.g. `gate_up_proj` in `Phi3ForCausalLM`.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__(base_layer)
        # The base_layer type is ColumnParallelLinear or
        # MergedColumnParallelLinear, their weight sharding logic is
        # inconsistent when TP is greater than 1.
        self.is_merged_col_linear = type(
            base_layer) is MergedColumnParallelLinear
        self.tp_size = get_tensor_model_parallel_world_size()
        self.output_size = self.base_layer.output_size_per_partition
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        return lora_a

    def _slice_output_shard(self, tensor: torch.Tensor) -> torch.Tensor:
        """Select this rank's output columns from an unsharded tensor.

        Slicing is along the last dimension, so the same logic serves the
        2-D ``lora_b`` and the 1-D bias.

        * MergedColumnParallelLinear: the full tensor holds two equally
          sized halves; this rank keeps its shard of each half and
          concatenates them.
        * ColumnParallelLinear: this rank keeps one contiguous shard.
        """
        tp_rank = get_tensor_model_parallel_rank()
        if self.is_merged_col_linear:
            shard_size = self.output_size // 2
            offset = tensor.shape[-1] // 2
            left = tensor[..., tp_rank * shard_size:(tp_rank + 1) *
                          shard_size]
            right = tensor[..., offset + tp_rank * shard_size:offset +
                           (tp_rank + 1) * shard_size]
            return torch.cat([left, right], dim=-1)
        shard_size = self.output_size
        start_idx = tp_rank * shard_size
        return tensor[..., start_idx:start_idx + shard_size]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        return self._slice_output_shard(lora_b)

    def slice_bias(self,
                   bias: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        # The bias must be split exactly like lora_b's output columns. For
        # merged layers a single contiguous chunk would mix the two packed
        # halves, so the shared helper is reused.
        if bias is None:
            return bias
        return self._slice_output_shard(bias)

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.base_layer.return_bias:
            return output

        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (eg. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(
        self, base_layer: Union[MergedColumnParallelLinear,
                                QKVParallelLinear]) -> None:
        super().__init__(base_layer)
        # One LoRA slice per packed sub-layer.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # the output_sizes in MergedColumnParallelLinear is not sharded by tp
        # we need to divide it by the tp_size to get correct slices size
        output_sizes = self.base_layer.output_sizes
        self.output_slices = tuple(
            divide(output_size, self.tp_size) for output_size in output_sizes)
        self.n_slices = len(self.output_slices)
        self.output_ids = (self.tp_rank, ) * self.n_slices

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate one A/B (and optional bias) buffer per output slice.

        Overridden because each slice's B and bias buffers are sized by that
        slice's own entry in ``output_slices``.
        """
        self.lora_config = lora_config

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                output_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for output_size in self.output_slices)
        if lora_config.bias_enabled:
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    output_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for output_size in self.output_slices)

    def slice_lora_a(
        self, lora_a: list[Union[torch.Tensor, None]]
    ) -> list[Union[torch.Tensor, None]]:
        return lora_a

    def slice_lora_b(
        self, lora_b: list[Union[torch.Tensor, None]]
    ) -> list[Union[torch.Tensor, None]]:
        # Slice into a copy: the caller may keep this list and pass it again,
        # and in-place slicing would make a later call re-slice tensors that
        # are already sharded (yielding empty slices on ranks > 0).
        sliced_lora_b = list(lora_b)
        for i, (shard_id, shard_size) in enumerate(
                zip(self.output_ids, self.output_slices)):
            if (lora_b_i := sliced_lora_b[i]) is not None:
                sliced_lora_b[i] = lora_b_i[:, shard_size *
                                            shard_id:shard_size *
                                            (shard_id + 1)]
        return sliced_lora_b

    def slice_bias(
        self, bias: list[Union[torch.Tensor,
                               None]]) -> list[Union[torch.Tensor, None]]:
        # Same copy-on-slice rule as slice_lora_b.
        sliced_bias = list(bias)
        for i, (shard_id, shard_size) in enumerate(
                zip(self.output_ids, self.output_slices)):
            if (bias_i := sliced_bias[i]) is not None:
                sliced_bias[i] = bias_i[shard_size * shard_id:shard_size *
                                        (shard_id + 1)]
        return sliced_bias

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Copy per-slice A/B (and bias) tensors into slot ``index``.

        ``lora_a``, ``lora_b`` and ``lora_bias`` are indexed per slice; a
        ``None`` entry leaves that slice zeroed after the reset.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        for i in range(self.n_slices):
            if (lora_a_i := lora_a[i]) is not None:
                self.lora_a_stacked[i][
                    index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
                        lora_a_i.T, non_blocking=True)
            if (lora_b_i := lora_b[i]) is not None:
                self.lora_b_stacked[i][
                    index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
                        lora_b_i.T, non_blocking=True)

        if lora_bias is not None:
            self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            for i in range(self.n_slices):
                if (lora_bias_i := lora_bias[i]) is not None:
                    # The destination is 1-D: copy as-is, since Tensor.T on
                    # non-2-D tensors is deprecated in PyTorch.
                    self.lora_bias_stacked[i][index,
                                              0, :lora_bias_i.shape[0]].copy_(
                                                  lora_bias_i,
                                                  non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is MergedColumnParallelLinear
                and len(packed_modules_list) == 2)


751
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    (and the LoRA bias, if any) must be partitioned according to the
    respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

        # The shard ids depend only on this rank, so compute them once here.
        # Previously they were set as a side effect of slice_lora_b(), which
        # made slice_bias() fail with AttributeError if called first.
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        # With KV-head replication, num_kv_head_replicas consecutive ranks
        # share one KV shard.
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas

        # There is only one LoRA layer
        self.n_slices = 1

    def _slice_qkv(self, tensor: torch.Tensor, dim: int) -> torch.Tensor:
        """Return this rank's Q, K and V shards of ``tensor`` along ``dim``.

        Along ``dim`` the tensor is laid out as ``[Q | K | V]`` with widths
        ``q_proj_total_size``, ``kv_proj_total_size`` and
        ``kv_proj_total_size``. The three local shards are concatenated
        back along the same dimension.
        """

        def take(start: int, size: int) -> torch.Tensor:
            # Plain slicing on ``dim`` only; all other dims are kept whole.
            index = [slice(None)] * tensor.dim()
            index[dim] = slice(start, start + size)
            return tensor[tuple(index)]

        q_start = self.q_proj_shard_size * self.q_shard_id
        kv_start = self.kv_proj_shard_size * self.kv_shard_id
        k_offset = self.q_proj_total_size
        v_offset = k_offset + self.kv_proj_total_size
        return torch.cat([
            take(q_start, self.q_proj_shard_size),
            take(k_offset + kv_start, self.kv_proj_shard_size),
            take(v_offset + kv_start, self.kv_proj_shard_size),
        ],
                         dim=dim)

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Shard ``lora_b`` for this rank; output features are on dim 1."""
        return self._slice_qkv(lora_b, dim=1)

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Shard the 1-D LoRA bias for this rank.

        The bias has a single (output) dimension, so both slicing and
        concatenation happen on dim 0. The previous ``torch.cat(..., dim=1)``
        raised an IndexError on these 1-D slices.
        """
        return self._slice_qkv(bias, dim=0)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: list,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Replace an exact ``QKVParallelLinear`` carrying a single LoRA."""
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1


819
class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
    """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        # One LoRA per output slice of the base layer (q, k and v).
        self.n_slices = len(self.base_layer.output_sizes)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)

        self.q_shard_id = self.tp_rank
        # With KV-head replication, num_kv_head_replicas consecutive ranks
        # share one KV shard.
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        # Per-slice local output width, in q, k, v order.
        self.output_slices = (
            self.q_proj_shard_size,
            self.kv_proj_shard_size,
            self.kv_proj_shard_size,
        )
        # Per-slice shard index on this rank, in q, k, v order.
        self.output_ids = (
            self.q_shard_id,
            self.kv_shard_id,
            self.kv_shard_id,
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate LoRA weights by delegating to the parent implementation.

        This override adds no logic of its own. The differing q/k/v slice
        widths are recorded in ``self.output_slices`` in ``__init__``;
        presumably the parent implementation consumes them (TODO: confirm).
        """
        super().create_lora_weights(max_loras, lora_config, model_config)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Replace an exact ``QKVParallelLinear`` packing q, k and v LoRAs."""
        return (type(source_layer) is QKVParallelLinear
                and len(packed_modules_list) == 3)
878

879

880
881
882
883
884
# TODO: Implement this
class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA):
    """Placeholder LoRA wrapper (the name suggests a cross-attention QKV
    projection). Not implemented yet: it adds nothing on top of
    ``BaseLayerWithLoRA``."""


885
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """LoRA wrapper around ``RowParallelLinear``.

    The base layer's input dimension is partitioned across tensor-parallel
    ranks, so each rank keeps only its rows of ``lora_a``. ``lora_b`` and
    the bias are used unsliced.
    """

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__(base_layer)

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # Input width is the per-partition size; output width is the full one.
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size

        # A single LoRA is applied to the whole layer.
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Keep the rows of ``lora_a`` that match this rank's input shard."""
        rows = slice(self.tp_rank * self.input_size,
                     (self.tp_rank + 1) * self.input_size)
        return lora_a[rows, :]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """``lora_b`` is not sharded for a row-parallel layer."""
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """The bias is not sharded for a row-parallel layer."""
        return bias

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        base = self.base_layer

        # Obtain this rank's shard of the input.
        if base.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            chunks = split_tensor_along_last_dim(input_,
                                                 num_partitions=base.tp_size)
            input_parallel = chunks[self.tp_rank].contiguous()

        # Matrix multiply on the local shard.
        output_parallel = self.apply(input_parallel)

        # Sum the partial results across ranks when requested.
        needs_reduce = base.reduce_results and base.tp_size > 1
        output_ = (tensor_model_parallel_all_reduce(output_parallel)
                   if needs_reduce else output_parallel)

        if base.skip_bias_add:
            # Hand the bias back to the caller instead of adding it.
            output, output_bias = output_, base.bias
        else:
            output = output_ if base.bias is None else output_ + base.bias
            output_bias = None

        return (output, output_bias) if base.return_bias else output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Replace only an exact ``RowParallelLinear`` (not subclasses)."""
        return type(source_layer) is RowParallelLinear
966

967

968
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[list[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    # The properties below forward to the wrapped LogitsProcessor so this
    # wrapper can stand in for it (see forward()).

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_all_gather(self):
        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate per-adapter buffers used when computing logits.

        Creates ``lora_a_stacked`` (max_loras, 1, max_lora_rank, hidden_size),
        ``lora_b_stacked`` (max_loras, 1, padded vocab, max_lora_rank) and
        ``embeddings_tensors`` (max_loras, lora_extra_vocab_size, hidden_size,
        filled with -inf), plus a device copy of the vocab reindexing map.

        Raises:
            ValueError: if the base layer's vocab size exceeds 257024.
        """
        # TODO: Verify if this condition can be further relaxed
        # Only an upper bound is enforced; smaller vocabularies are accepted.
        # (The check used to be written as ``32000 < vocab_size > 257024``,
        # which Python evaluates to just ``vocab_size > 257024``, paired with
        # a message that claimed a lower bound that was never checked.)
        max_vocab_size = 257024
        vocab_size = self.base_layer.vocab_size
        if vocab_size > max_vocab_size:
            raise ValueError("When using LoRA, vocab size must be "
                             f"<= {max_vocab_size}, got {vocab_size}")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # -inf so that unused extra-vocab rows never win after the matmul
        # in _get_logits (they are also cleaned by nan_to_num_ there).
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        """Clear slot ``index``: zero A/B weights, reset embeddings to -inf."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter into slot ``index``.

        ``lora_a`` and ``lora_b`` are stored transposed into the leading
        region of their stacked buffers. ``embeddings_tensor``, if given, is
        written into the leading region of the extra-vocab embeddings.
        ``bias`` is accepted for interface compatibility but not used.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute logits including LoRA deltas and added LoRA vocabulary.

        Returns ``None`` when ``_gather_logits`` returns ``None`` (on ranks
        that do not receive gathered logits); otherwise the logits trimmed
        to ``vocab_size`` columns.
        """
        # Get the logits for the next tokens.
        logits = lm_head.quant_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias

        # Gather logits for TP
        logits = self.base_layer._gather_logits(logits)

        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        # Shape (max_loras + 1, extra_vocab, num_tokens): one row block per
        # adapter plus a trailing sentinel block filled with -inf below.
        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])

        neg_inf, pos_inf = current_platform.get_infinity_values(
            lora_logits.dtype)

        # Presumably selected for tokens without an active LoRA -- confirm
        # against punica_wrapper.sampler_indices_padded.
        lora_logits[-1] = neg_inf
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded

        if current_platform.is_tpu():
            indices_padded = indices_padded[:logits.size(0)]

        # Flatten to ((max_loras + 1) * num_tokens, extra_vocab) and pick one
        # row per token; NaN/inf from -inf embeddings are clamped.
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
                                                      posinf=pos_inf,
                                                      neginf=neg_inf))

        # HPU needs special handling to prune out dummy samples.
        if current_platform.is_hpu():
            lora_logits = lora_logits[:logits.shape[0], :]

        # Write the added-vocabulary logits right after the original vocab.
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_logits(
                logits, hidden_states, self.lora_a_stacked,
                self.lora_b_stacked, 1.0)

        # Platforms that cannot update in place return a new tensor instead.
        if not current_platform.can_update_inplace():
            logits = lora_output

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        # Run the base class's forward with this wrapper as ``self`` so that
        # its call to _get_logits dispatches to the LoRA-aware version above.
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False
1198
1199


1200
class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
    """RoPE with linear scaling shared by several LoRA adapters.

    ``create_lora_weights`` rebuilds the wrapped rotary embedding as a
    ``LinearScalingRotaryEmbedding`` that holds every distinct scaling factor:
    the base layer's own factor (1.0 for a plain ``RotaryEmbedding``) plus
    ``lora_config.long_lora_scaling_factors``. ``forward`` passes the punica
    wrapper's ``long_lora_indices`` to it as ``offsets``.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Swap in a multi-factor ``LinearScalingRotaryEmbedding``."""
        previous = self.base_layer
        extra_factors = lora_config.long_lora_scaling_factors or ()
        if isinstance(previous, LinearScalingRotaryEmbedding):
            own_factor = previous.scaling_factor
        else:
            own_factor = 1.0
        # Deduplicated and ascending, including the base layer's own factor.
        merged_factors = sorted({own_factor, *extra_factors})
        self.base_layer = LinearScalingRotaryEmbedding(
            previous.head_size,
            previous.rotary_dim,
            previous.max_position_embeddings,
            previous.base,
            previous.is_neox_style,
            merged_factors,
            previous.dtype,
        )

    def reset_lora(self, index: int):
        # No per-adapter state to clear.
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        # No per-adapter weights to load.
        ...

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        offsets = self.punica_wrapper.long_lora_indices
        return self.base_layer(positions, query, key, offsets=offsets)

    @property
    def scaling_factor_to_offset(self) -> dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return type(source_layer) in (LinearScalingRotaryEmbedding,
                                      RotaryEmbedding)

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()