layers.py 48.5 KB
Newer Older
1
2
3
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
4
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
5
6
7
8
9
10

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

11
from vllm.adapter_commons.layers import AdapterMapping
12
from vllm.config import LoRAConfig
13
14
15
16
17
18
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_gather)
19
from vllm.distributed.utils import divide
20
from vllm.lora.punica import PunicaWrapper
21
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
22
                                               MergedColumnParallelLinear,
23
                                               QKVParallelLinear,
24
                                               ReplicatedLinear,
25
                                               RowParallelLinear)
26
from vllm.model_executor.layers.logits_processor import LogitsProcessor
27
28
from vllm.model_executor.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding, RotaryEmbedding)
29
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
    VocabParallelEmbedding)
31
32
33
34
35

if TYPE_CHECKING:
    pass


36
37
38
def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device for where to place the LoRA tensors."""
Jee Li's avatar
Jee Li committed
39
    # unquantizedLinear
40
41
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
42
    # GPTQ/AWQ
Jee Li's avatar
Jee Li committed
43
44
45
46
47
48
49
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # marlin
    elif hasattr(base_layer, "B"):
        return base_layer.B.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")
50
51


52
53
54
55
56
57
58
def _not_fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of not using fully sharded loras
    intended to wrap can_replace_layer()
    """

    def dec(*args, **kwargs):
59
60
        decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
        condition = (not kwargs["lora_config"].fully_sharded_loras
61
62
63
64
65
66
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec


67
@dataclass
class LoRAMapping(AdapterMapping):
    """Adapter mapping used by the LoRA layers, plus a prefill flag."""

    # Presumably True when the mapped batch is a prefill step; confirm
    # against the callers that construct this mapping.
    is_prefill: bool = False
70
71
72
73


class BaseLayerWithLoRA(nn.Module):
    """Common interface for the LoRA-wrapped layers in this module.

    Subclasses wrap a base layer and implement these hooks to allocate,
    load, reset and (for tensor parallelism) slice per-adapter LoRA weights.
    """

    def slice_lora_a(
        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper: PunicaWrapper,
    ):
        """Store the Punica wrapper used by subclasses to apply LoRA."""
        self.punica_wrapper: PunicaWrapper = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError

126
127
128
129
130
131

class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``VocabParallelEmbedding``.

    ``lora_config.merge_lora`` selects one of two modes:

    * unmerged: ``forward`` looks up LoRA A for every token and lets the
      Punica wrapper expand it through LoRA B on top of the base embedding;
    * merged: ``set_lora`` adds ``lora_a @ lora_b`` into the base embedding
      weight, and ``forward`` only runs the base layer.

    In both modes, adapter-provided extra-vocabulary embeddings are copied
    into the base layer's added-vocab rows.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.embeddings_slice: Optional[Tuple[int, int]]
        self.embeddings_weights: Optional[torch.Tensor]
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Allocate stacked LoRA buffers for up to ``max_loras`` adapters.

        If this partition owns added-vocab rows, those rows of the base
        weight are zeroed and a view of them is kept in
        ``embeddings_weights`` so ``set_lora`` can write extra embeddings.
        """
        # set_lora() and forward() read self.lora_config.merge_lora, so the
        # config must be stored here (as the other LoRA layers do).
        self.lora_config = lora_config

        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:self.
                base_layer.num_org_embeddings_per_partition +
                self.base_layer.num_added_embeddings_per_partition]
            # This partition's added-token range, relative to the start of
            # the added vocabulary.
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index -
                self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index -
                self.base_layer.org_vocab_size)
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition:].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        # (max_loras, lora_extra_vocab_size, embedding_dim)
        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        # (max_loras, org_vocab_size + lora_extra_vocab_size, max_lora_rank)
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        # (max_loras, 1, embedding_dim, max_lora_rank)
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        # 2-D view of lora_a_stacked so forward() can do a single
        # F.embedding lookup with per-token row offsets
        # (embeddings_indices[1]).
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        """Zero the LoRA buffers and extra embeddings of slot ``index``.

        NOTE(review): this does not undo a delta already merged into the
        base weight by set_lora() in merge mode.
        """
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Load adapter ``index``.

        ``lora_a`` is copied as-is into ``lora_a_stacked`` and ``lora_b`` is
        copied transposed into ``lora_b_stacked``. In merge mode,
        ``lora_a @ lora_b`` is also added into the base embedding weight.
        ``embeddings_tensor`` (optional) holds the adapter's extra-vocab
        embeddings.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

        self.lora_a = lora_a.to(self.device)
        self.lora_b = lora_b.to(self.device)

        if self.lora_config.merge_lora:
            # NOTE(review): merges are cumulative; calling set_lora for
            # several slots (or the same slot twice) adds every delta into
            # the shared base weight.
            self._merge_lora_into_base_weight(self.lora_a, self.lora_b)

        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                assert self.embeddings_weights is not None
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def _merge_lora_into_base_weight(self, lora_a: torch.Tensor,
                                     lora_b: torch.Tensor) -> None:
        """Add ``lora_a @ lora_b`` to this partition's original-vocab rows.

        ``lora_a`` rows are indexed by global token id (the unmerged forward
        path looks them up with the raw token ids), so only the rows
        belonging to this shard's original vocabulary are merged.
        """
        shard = self.base_layer.shard_indices
        start = shard.org_vocab_start_index
        end = min(shard.org_vocab_end_index, lora_a.shape[0])
        if end <= start:
            return
        weight = self.base_layer.weight.data
        delta = torch.matmul(lora_a[start:end], lora_b)
        # NOTE(review): relies on the partition storing its original-vocab
        # shard first (padding and added vocab follow, as the
        # num_org_embeddings_per_partition slicing in create_lora_weights
        # suggests) -- confirm against VocabParallelEmbedding's loader.
        weight[:end - start].add_(delta.to(dtype=weight.dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed token ids ``x`` and add the LoRA contribution.

        Note that ``x`` is modified in place: ids of tokens beyond the
        original vocabulary are shifted by ``embeddings_indices[0]`` before
        the base lookup.
        """
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embeddings_indices = self.punica_wrapper.embeddings_indices
        indices = embeddings_indices[0].view_as(x)

        if self.lora_config.merge_lora:
            # The LoRA delta was folded into the base weight in set_lora().
            return self.base_layer.forward(
                x.add_(indices * added_tokens_mask))

        indices_0 = embeddings_indices[1].view_as(x)
        # Must be computed before the in-place add_ below modifies x.
        full_lora_a_embeddings = F.embedding(
            x + indices_0,
            self.lora_a_stacked_2d,
        )
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1],
                -1,
            )

        # Embedding layer only need expand op
        self.punica_wrapper.add_expand(full_output,
                                       full_lora_a_embeddings,
                                       self.lora_b_stacked,
                                       add_input=True)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Replace only exact ``VocabParallelEmbedding`` instances."""
        return type(source_layer) is VocabParallelEmbedding

285

286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper for a ``ReplicatedLinear`` layer.

    The base layer is replicated rather than sharded, so each adapter's
    LoRA A and LoRA B are stored whole, with no tensor-parallel slicing.
    """

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = base_layer.input_size
        self.output_size = base_layer.output_size
        self.device = _get_lora_device(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zero-filled LoRA buffers for ``max_loras`` adapter slots.

        A is ``(max_loras, 1, max_lora_rank, input_size)``;
        B is ``(max_loras, 1, output_size, max_lora_rank)``.
        """
        self.lora_config = lora_config
        rank = lora_config.max_lora_rank
        buf_dtype = lora_config.lora_dtype
        self.lora_a_stacked = torch.zeros(max_loras, 1, rank, self.input_size,
                                          dtype=buf_dtype, device=self.device)
        self.lora_b_stacked = torch.zeros(max_loras, 1, self.output_size, rank,
                                          dtype=buf_dtype, device=self.device)

    def reset_lora(self, index: int):
        """Zero the A and B buffers of adapter slot ``index``."""
        for stacked in (self.lora_a_stacked, self.lora_b_stacked):
            stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Write adapter ``index`` into the stacked buffers.

        Both matrices are stored transposed relative to how they arrive;
        any unused tail of the slot stays zero after ``reset_lora``.
        ``embeddings_tensor`` is not used by this layer.
        """
        self.reset_lora(index)

        a_rows, a_cols = lora_a.shape[0], lora_a.shape[1]
        a_slot = self.lora_a_stacked[index, 0, :a_cols, :a_rows]
        a_slot.copy_(lora_a.T, non_blocking=True)

        b_rows, b_cols = lora_b.shape[0], lora_b.shape[1]
        b_slot = self.lora_b_stacked[index, 0, :b_cols, :b_rows]
        b_slot.copy_(lora_b.T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base layer's quant method, then hand the result to the
        Punica wrapper to add the LoRA contribution (scale 1.0)."""
        result = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora(result, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return result

    def forward(self, input_):
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias (only when the base layer skips adding it itself)
        """
        skip_bias = self.base_layer.skip_bias_add
        bias = None if skip_bias else self.base_layer.bias

        # Matrix multiply.
        output = self.apply(input_, bias)

        output_bias = self.base_layer.bias if skip_bias else None
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Replace only exact ``ReplicatedLinear`` instances."""
        return type(source_layer) is ReplicatedLinear


379
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.

    LoRA B is sliced for tensor parallelism.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size = self.base_layer.input_size
        # Per-rank output width of the column-sharded base layer.
        self.output_size = self.base_layer.output_size_per_partition
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed LoRA A/B buffers for ``max_loras`` adapters.

        With fully sharded LoRAs, the rank dimension of A is split across
        tensor-parallel ranks.
        """
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size_per_partition,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Width of the LoRA B columns owned by this rank.
        self.output_dim = self.lora_b_stacked.shape[2]

    def reset_lora(self, index: int):
        """Zero the A and B buffers of adapter slot ``index``."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """LoRA A is replicated across ranks, so it is returned unchanged."""
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Keep only this rank's contiguous block of LoRA B output columns."""
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.output_dim
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Slice (under TP) and copy adapter ``index`` into the buffers.

        Both matrices are stored transposed; ``embeddings_tensor`` is unused.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base quant method, then let the Punica wrapper add the
        LoRA contribution (scale 1.0) to its output."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Replace plain column-parallel layers, or merged ones that pack a
        single module."""
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)

502
503
504
505
506
507
508
509
510
511
512
513
514
515

class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (eg. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate one (A, B) buffer pair per packed slice.

        Raises:
            ValueError: if the base layer does not have exactly two output
                slices of equal size.
        """
        self.lora_config = lora_config
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))
        # Each slice's B covers half of this rank's output columns.
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.output_size // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(n_slices))

        # Per-slice output width on this rank.
        self.output_dim = self.lora_b_stacked[0].shape[2]

    def reset_lora(self, index: int):
        """Zero both slices' A and B buffers for adapter slot ``index``."""
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """LoRA A is replicated across ranks, so it is returned unchanged."""
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Keep this rank's output columns of every non-None LoRA B slice.

        Slices are handled independently, so an adapter that targets only
        one of the packed sub-layers is still sharded correctly. Previously
        a single None caused the other slice to be returned at full width,
        which then failed to fit the per-rank buffer in ``set_lora``.
        """
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        return [
            None if b is None else b[:, start_idx:end_idx] for b in lora_b
        ]

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Slice (under TP) and copy each present slice of adapter ``index``.

        ``lora_a`` / ``lora_b`` are two-element sequences; a slice whose A
        is None is skipped. Matrices are stored transposed.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base quant method, then add both slices' LoRA outputs via
        the Punica wrapper (each slice ``output_dim`` wide, scale 1.0)."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora_packed_nslice(
            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
            (self.output_dim, self.output_dim))
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Replace merged column-parallel layers that pack two modules."""
        return (type(source_layer) is MergedColumnParallelLinear
                and len(packed_modules_list) == 2)
629

630
631

class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    contain only a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Cut this rank's Q, K and V column blocks out of the full LoRA B
        (laid out as [Q | K | V]) and concatenate them.

        KV shards are shared by ``num_kv_head_replicas`` consecutive ranks.
        Side effect: records ``q_shard_id`` / ``kv_shard_id`` on ``self``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        lora_b_q = lora_b[:, self.q_proj_shard_size *
                          self.q_shard_id:self.q_proj_shard_size *
                          (self.q_shard_id + 1)]
        k_offset = self.q_proj_total_size
        lora_b_k = lora_b[:, k_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:k_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        v_offset = k_offset + self.kv_proj_total_size
        lora_b_v = lora_b[:, v_offset +
                          self.kv_proj_shard_size * self.kv_shard_id:v_offset +
                          self.kv_proj_shard_size * (self.kv_shard_id + 1)]
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Slice (under TP) and copy adapter ``index`` into the buffers.

        Both matrices are stored transposed; ``embeddings_tensor`` is unused.
        """
        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Replace QKV layers that pack a single module (one LoRA)."""
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1


class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).

    When ``lora_config.merge_lora`` is set, ``set_lora`` adds the LoRA delta
    ``A @ B`` directly into ``base_layer.weight``, and ``apply`` runs only the
    base layer. Otherwise the adapters are kept in the stacked buffers and
    applied through ``punica_wrapper.add_lora_packed_nslice``.
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed q/k/v LoRA buffers for ``max_loras`` adapter slots.

        Also records this rank's q/kv shard sizes and shard ids, which
        ``slice_lora_b`` uses to cut each B matrix down to the local columns.
        """
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = self.tp_rank
        # Ranks that replicate the same KV heads map to the same kv shard.
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        # Output widths of the q, k and v slices on this rank.
        self.output_slices = (
            self.q_proj_shard_size,
            self.kv_proj_shard_size,
            self.kv_proj_shard_size,
        )

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        # q, k, v: one (max_loras, 1, rank_per_partition, input_size) A buffer
        # per slice.
        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in self.output_slices)
        # q, k, v: one (max_loras, 1, slice_width, max_lora_rank) B buffer per
        # slice.
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                slice_width,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for slice_width in self.output_slices)

        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        # lazily initialized.
        self.indices: torch.Tensor
        self.indices_len: List[int]

    def reset_lora(self, index: int):
        """Zero the q/k/v A and B buffers of adapter slot ``index``.

        In merge mode the buffers are not used, so this is a no-op.
        NOTE(review): a LoRA already merged into ``base_layer.weight`` is
        not subtracted here; confirm whether unmerging is needed.
        """
        if self.lora_config.merge_lora:
            return
        for a_stacked, b_stacked in zip(self.lora_a_stacked,
                                        self.lora_b_stacked):
            a_stacked[index] = 0
            b_stacked[index] = 0

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """A matrices are not sharded for this layer; return them as-is."""
        return lora_a

    def slice_lora_b(
        self, lora_b: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Keep only this rank's output columns of each q/k/v B matrix.

        ``None`` entries stay ``None``.
        """

        def _local_columns(tensor: Optional[torch.Tensor], shard_size: int,
                           shard_id: int) -> Optional[torch.Tensor]:
            if tensor is None:
                return None
            return tensor[:, shard_size * shard_id:shard_size * (shard_id + 1)]

        return [
            _local_columns(lora_b[0], self.q_proj_shard_size, self.q_shard_id),
            _local_columns(lora_b[1], self.kv_proj_shard_size,
                           self.kv_shard_id),
            _local_columns(lora_b[2], self.kv_proj_shard_size,
                           self.kv_shard_id),
        ]

    def _merge_into_base_weight(
        self,
        lora_a: List[Optional[torch.Tensor]],
        lora_b: List[Optional[torch.Tensor]],
    ) -> None:
        """Add the packed q/k/v LoRA delta ``A @ B`` onto the base weight.

        A slice whose A or B is missing or empty contributes zeros. This keeps
        the concatenated delta aligned with the packed q/k/v layout instead
        of shifting the remaining slices.

        Raises:
            ValueError: if the delta cannot be matched to the base weight's
                shape in either orientation.
        """
        base_weight = self.base_layer.weight
        deltas = []
        for i, slice_width in enumerate(self.output_slices):
            weight_a, weight_b = lora_a[i], lora_b[i]
            if (weight_a is None or weight_b is None
                    or weight_a.numel() == 0 or weight_b.numel() == 0):
                deltas.append(
                    torch.zeros(self.input_size,
                                slice_width,
                                dtype=base_weight.dtype,
                                device=self.device))
            else:
                # (input_size, rank) @ (rank, slice_width)
                deltas.append(
                    torch.matmul(weight_a.to(self.device),
                                 weight_b.to(self.device)))
        # Shape: (input_size, q + k + v width).
        qkv_delta = torch.cat(deltas, dim=-1)

        # Try the transposed orientation first. A square weight would
        # otherwise be merged without being transposed.
        if qkv_delta.T.shape == base_weight.shape:
            qkv_delta = qkv_delta.T
        elif qkv_delta.shape != base_weight.shape:
            raise ValueError(
                f"Cannot merge LoRA delta of shape {tuple(qkv_delta.shape)} "
                f"into base weight of shape {tuple(base_weight.shape)}")
        base_weight.data.copy_(qkv_delta + base_weight.data)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Install a q/k/v LoRA adapter into slot ``index``.

        ``lora_a`` / ``lora_b`` are indexed as [q, k, v], and any entry may be
        ``None``. Under tensor parallelism they are first sliced to this
        rank. In merge mode the delta is folded into ``base_layer.weight``.
        Otherwise each present matrix is copied, transposed, into its
        stacked buffer. ``embeddings_tensor`` is unused.
        """
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        if self.lora_config.merge_lora:
            self._merge_into_base_weight(lora_a, lora_b)
            return

        for i, (a_stacked, b_stacked) in enumerate(
                zip(self.lora_a_stacked, self.lora_b_stacked)):
            if lora_b[i] is not None:
                b_i = lora_b[i]
                b_stacked[index, 0, :b_i.shape[1], :b_i.shape[0]].copy_(
                    b_i.T, non_blocking=True)
            if lora_a[i] is not None:
                a_i = lora_a[i]
                a_stacked[index, 0, :a_i.shape[1], :a_i.shape[0]].copy_(
                    a_i.T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base layer. When adapters are not merged, also add the
        per-slice LoRA contributions via the punica wrapper."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        if not self.lora_config.merge_lora:
            self.punica_wrapper.add_lora_packed_nslice(output, x,
                                                       self.lora_a_stacked,
                                                       self.lora_b_stacked,
                                                       1.0,
                                                       self.output_slices)
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Accept exactly ``QKVParallelLinear`` packed from three modules."""
        return (type(source_layer) is QKVParallelLinear
                and len(packed_modules_list) == 3)
925

926
927
928
929
930
931

class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``RowParallelLinear`` layer.

    The base layer's input dimension is split across tensor-parallel ranks.
    Accordingly, ``slice_lora_a`` keeps only this rank's rows of LoRA A,
    while ``slice_lora_b`` returns B unchanged. With
    ``fully_sharded_loras`` the B buffer is sized ``output_size / tp_size``.
    """

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        # Per-rank input width and full output width of the wrapped layer.
        self.input_size = base_layer.input_size_per_partition
        self.output_size = base_layer.output_size
        self.device = _get_lora_device(base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed A/B buffers for ``max_loras`` adapter slots."""
        self.lora_config = lora_config
        self.tp_rank = get_tensor_model_parallel_rank()
        max_rank = lora_config.max_lora_rank
        self.lora_a_stacked = torch.zeros(
            (max_loras, 1, max_rank, self.input_size),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

        world_size = get_tensor_model_parallel_world_size()
        if lora_config.fully_sharded_loras:
            b_rows = divide(self.output_size, world_size)
        else:
            b_rows = self.output_size
        self.lora_b_stacked = torch.zeros(
            (max_loras, 1, b_rows, max_rank),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

    def reset_lora(self, index: int):
        """Zero adapter slot ``index`` in both stacked buffers."""
        for buffer in (self.lora_a_stacked, self.lora_b_stacked):
            buffer[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return the rows of ``lora_a`` that belong to this rank's input
        shard."""
        rank = get_tensor_model_parallel_rank()
        width = self.input_size
        return lora_a[rank * width:(rank + 1) * width, :]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """LoRA B is not sliced for this layer."""
        return lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Clear slot ``index`` and copy the (possibly sliced) A/B matrices,
        transposed, into it. ``embeddings_tensor`` is unused."""
        self.reset_lora(index)

        if self.base_layer.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        for stacked, source in ((self.lora_a_stacked, lora_a),
                                (self.lora_b_stacked, lora_b)):
            n_rows, n_cols = source.shape[0], source.shape[1]
            stacked[index, 0, :n_cols, :n_rows].copy_(source.T,
                                                     non_blocking=True)

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        """Base matmul (no bias) plus the LoRA contribution."""
        result = self.base_layer.quant_method.apply(self.base_layer, x)
        self.punica_wrapper.add_lora(result, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return result

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        base = self.base_layer

        # Select this rank's slice of the input unless it is already split.
        if base.input_is_parallel:
            input_parallel = input_
        else:
            local_rank = get_tensor_model_parallel_rank()
            pieces = split_tensor_along_last_dim(input_,
                                                 num_partitions=base.tp_size)
            input_parallel = pieces[local_rank].contiguous()

        # Matrix multiply, then sum the partial results across ranks.
        output_parallel = self.apply(input_parallel)
        if base.reduce_results and base.tp_size > 1:
            reduced = tensor_model_parallel_all_reduce(output_parallel)
        else:
            reduced = output_parallel

        if base.skip_bias_add:
            return reduced, base.bias
        if base.bias is None:
            return reduced, None
        return reduced + base.bias, None

    @property
    def weight(self):
        """Unquantized ``weight`` if present, otherwise the ``qweight``."""
        if hasattr(self.base_layer, "weight"):
            return self.base_layer.weight
        return self.base_layer.qweight

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Accept exactly ``RowParallelLinear`` (not subclasses)."""
        return type(source_layer) is RowParallelLinear

1065

1066
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    # Largest base vocabulary supported with LoRA.
    # TODO: Verify if this limit can be further relaxed.
    _MAX_LORA_VOCAB_SIZE = 257024

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[List[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    # Read-only pass-throughs to the wrapped LogitsProcessor.
    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_gather(self):
        return self.base_layer.use_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate LoRA A/B buffers and the extra-vocab embedding table.

        Raises:
            ValueError: if the base vocab size exceeds the supported maximum.
        """
        # Only an upper bound is enforced; smaller vocabularies are accepted.
        vocab_size = self.base_layer.vocab_size
        if vocab_size > self._MAX_LORA_VOCAB_SIZE:
            raise ValueError("When using LoRA, vocab size must be <= "
                             f"{self._MAX_LORA_VOCAB_SIZE}, got {vocab_size}")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Embeddings for tokens added by each adapter. Unused rows stay -inf.
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        """Clear slot ``index``: zero A/B and set its embeddings to -inf."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Install an adapter (and optional extra-vocab embeddings) into
        slot ``index``."""
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute logits over the base vocab plus LoRA-added tokens.

        Returns ``None`` when ``tensor_model_parallel_gather`` yields no
        tensor on this rank.
        """
        # Get the logits for the next tokens.
        logits = lm_head.linear_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        # Score every adapter's extra-vocab embeddings. One additional slot
        # at the end is filled with -inf.
        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded
        # Pick, per row, the extra-vocab scores of the adapter selected by
        # ``sampler_indices_padded``. NaNs become -inf.
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                      posinf=float("inf"),
                                                      neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1], ] = lora_logits

        # LogitsProcessorWithLoRA always using bgmv
        self.punica_wrapper.add_lora_logits(logits, hidden_states,
                                            self.lora_a_stacked,
                                            self.lora_b_stacked, 1.0)

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        # Run the base class's forward with ``self`` bound to this wrapper.
        # NOTE(review): presumably this lets the base forward dispatch to the
        # overridden ``_get_logits`` above; confirm against LogitsProcessor.
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301


class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
    """Implements RoPE-scaled embeddings with linear scaling for
    multiple LoRA adapters with a specialized kernel.

    ``create_lora_weights`` replaces the wrapped RotaryEmbedding /
    LinearScalingRotaryEmbedding with a LinearScalingRotaryEmbedding that
    holds every configured scaling factor. At forward time the factor is
    chosen through the ``offsets`` taken from
    ``punica_wrapper.long_lora_indices``.
    """

    def __init__(self, base_layer: RotaryEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    @property
    def scaling_factors(self):
        return self.base_layer.scaling_factors

    @property
    def rotary_dim(self):
        return self.base_layer.rotary_dim

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Rebuild the base layer with the union of the base scaling factor
        (1.0 for a plain RotaryEmbedding) and the long-LoRA factors."""
        extra_factors = (list(lora_config.long_lora_scaling_factors)
                         if lora_config.long_lora_scaling_factors else [])
        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
        # De-duplicate, then sort ascending.
        scaling_factors = sorted({base_scaling_factor, *extra_factors})
        self.base_layer = LinearScalingRotaryEmbedding(
            self.base_layer.head_size,
            self.base_layer.rotary_dim,
            self.base_layer.max_position_embeddings,
            self.base_layer.base,
            self.base_layer.is_neox_style,
            scaling_factors,
            self.base_layer.dtype,
        )

    def reset_lora(self, index: int):
        # No per-adapter state to clear.
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        # No per-adapter weights to install.
        ...

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embedding, selecting scaling via long-LoRA offsets."""
        return self.base_layer(
            positions,
            query,
            key,
            offsets=self.punica_wrapper.long_lora_indices,
        )

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self.base_layer.scaling_factor_to_offset

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        # Exact type match on purpose: subclasses are not accepted.
        return type(source_layer) in (LinearScalingRotaryEmbedding,
                                      RotaryEmbedding)

    def extra_repr(self) -> str:
        return self.base_layer.extra_repr()